diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa29..fb623c07 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -676,7 +676,15 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): dp_size = mesh.shape.get("data", 1) if dp_size == -1 or dp_size == 0: dp_size = 1 - vae_spatial = (2 * total_devices) // dp_size + non_dp_size = total_devices // dp_size + # VAE activations are huge; we want to maximize VAE spatial sharding (vae_spatial). + # Ideally, we want 2x the Transformer's model parallel size (2 * non_dp_size). + # However, vae_spatial must divide total_devices, which requires dp_size to be even. + # If dp_size is odd (e.g., single-device runs), we fall back to 1x (non_dp_size). + if dp_size % 2 == 0: + vae_spatial = 2 * non_dp_size + else: + vae_spatial = non_dp_size assert ( total_devices % vae_spatial == 0 diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9c6c9125..19f3ec30 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -281,8 +281,8 @@ def user_init(raw_keys): raw_keys["global_batch_size_to_train_on"], ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - if raw_keys.get("vae_spatial", -1) == -1: - raw_keys["vae_spatial"] = 1 + if "vae_spatial" not in raw_keys: + raw_keys["vae_spatial"] = -1 def get_num_slices(raw_keys):