From 69cbe9128b13d77d933c150f4731e7f0ab1fbf46 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Fri, 15 May 2026 13:47:47 +0000 Subject: [PATCH] Fix VAE spatial sharding dynamic calculation bug in Wan pipeline. Previously, setting `vae_spatial: -1` in the config (intended to trigger dynamic calculation of the VAE spatial sharding axis size) was ineffective because `pyconfig.py` prematurely overrode any `-1` or missing `vae_spatial` value to `1`. Furthermore, the dynamic calculation formula in `wan_pipeline.py` (`vae_spatial = (2 * total_devices) // dp_size`) was not robust. On single-device runs (where `total_devices=1` and `dp_size=1`) or configurations with odd data parallel (DP) sizes, it would calculate a `vae_spatial` value (e.g., 2) that does not divide `total_devices`, failing the mesh validation assertion. --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 10 +++++++++- src/maxdiffusion/pyconfig.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa29..fb623c07 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -676,7 +676,15 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): dp_size = mesh.shape.get("data", 1) if dp_size == -1 or dp_size == 0: dp_size = 1 - vae_spatial = (2 * total_devices) // dp_size + non_dp_size = total_devices // dp_size + # VAE activations are huge; we want to maximize VAE spatial sharding (vae_spatial). + # Ideally, we want 2x the Transformer's model parallel size (2 * non_dp_size). + # However, vae_spatial must divide total_devices, which requires dp_size to be even. + # If dp_size is odd (e.g., single-device runs), we fall back to 1x (non_dp_size). 
+ if dp_size % 2 == 0: + vae_spatial = 2 * non_dp_size + else: + vae_spatial = non_dp_size assert ( total_devices % vae_spatial == 0 diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9c6c9125..19f3ec30 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -281,8 +281,8 @@ def user_init(raw_keys): raw_keys["global_batch_size_to_train_on"], ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - if raw_keys.get("vae_spatial", -1) == -1: - raw_keys["vae_spatial"] = 1 + if "vae_spatial" not in raw_keys: + raw_keys["vae_spatial"] = -1 def get_num_slices(raw_keys):