
Fix VAE spatial sharding dynamic calculation bug in Wan pipeline. #407

Open: ninatu wants to merge 1 commit into main from ninatu/wan_vae_sharding_bug


Conversation

Collaborator

@ninatu ninatu commented May 15, 2026

Previously, setting `vae_spatial: -1` in the config, which is intended to trigger dynamic calculation of the VAE spatial sharding axis size, had no effect because `pyconfig.py` prematurely overrode any `-1` or missing `vae_spatial` value to `1`.
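The effect of that override can be illustrated with a minimal sketch. Both functions below are hypothetical stand-ins written for this illustration; they are not the actual code in `pyconfig.py` or `wan_pipeline.py`, and they only mirror the behavior described above.

```python
def premature_override(raw_config: dict) -> dict:
    """Hypothetical stand-in for the old config handling:
    a -1 or missing vae_spatial is replaced by 1."""
    config = dict(raw_config)
    if config.get("vae_spatial", -1) == -1:
        config["vae_spatial"] = 1
    return config


def needs_dynamic_vae_spatial(config: dict) -> bool:
    """Hypothetical pipeline-side check for the -1 sentinel."""
    return config.get("vae_spatial", -1) == -1


user_config = {"vae_spatial": -1}
seen_by_pipeline = premature_override(user_config)

# The user asked for dynamic calculation, but the sentinel is gone
# by the time the pipeline would inspect it.
print(needs_dynamic_vae_spatial(user_config))
print(needs_dynamic_vae_spatial(seen_by_pipeline))
```

Under these assumptions, the sentinel never reaches the pipeline, so the dynamic branch cannot be taken regardless of what the user sets.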

Furthermore, the dynamic calculation formula in `wan_pipeline.py` (`vae_spatial = (2 * total_devices) // dp_size`) was not robust. On single-device runs (where `total_devices=1` and `dp_size=1`) or in configurations with odd data parallel (DP) sizes, it would calculate a `vae_spatial` value (e.g., 2) that does not divide `total_devices`, failing the mesh validation assertion.
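A short sketch of the arithmetic shows where the old formula breaks. The formula is taken from the description above. The divisibility check is an assumed stand-in for the mesh validation assertion, whose exact form is not shown here, and the device counts are illustrative values only.

```python
def old_vae_spatial(total_devices: int, dp_size: int) -> int:
    """The formula quoted in the PR description."""
    return (2 * total_devices) // dp_size


def divides_device_count(vae_spatial: int, total_devices: int) -> bool:
    """Assumed stand-in for the mesh validation:
    vae_spatial must evenly divide total_devices."""
    return vae_spatial > 0 and total_devices % vae_spatial == 0


# (total_devices, dp_size): single device, odd DP size, even DP size.
cases = [(1, 1), (8, 3), (8, 2)]
results = {}
for total_devices, dp_size in cases:
    vae_spatial = old_vae_spatial(total_devices, dp_size)
    results[(total_devices, dp_size)] = (
        vae_spatial,
        divides_device_count(vae_spatial, total_devices),
    )
    print(total_devices, dp_size, results[(total_devices, dp_size)])
```

With these inputs, the single-device case yields 2 and the `dp_size=3` case yields 5. Neither value divides its device count, which matches the failure described above.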

@ninatu ninatu requested a review from entrpn as a code owner May 15, 2026 14:47

@Perseus14 Perseus14 requested a review from eltsai May 15, 2026 15:24


2 participants