12 changes: 12 additions & 0 deletions src/maxdiffusion/common_types.py
@@ -95,3 +95,15 @@
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, CONTEXT],
]

### Common axis rules for 2D Ulysses + ring attention ###
# Public configs shard sequence on `context`; attention code privately reshapes
# that axis into hidden ring and Ulysses axes for the hybrid kernel.
ULYSSES_RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],
🟡 This axis rule maps the cross-attention KV length to None (replicated), matching SEQUENCE_PARALLEL_AXIS_RULES. Given that ulysses_ring falls back to tokamax_flash for cross-attention, this is consistent and ensures compatibility when the encoder sequence is not sharded across the context mesh.

]
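
Note (not part of the diff): for readers new to logical axis rules, a minimal sketch of what entries like these resolve to, assuming Flax's standard `logical_to_mesh_axes` helper; the lowercase logical names are stand-ins for the constants above.

```python
# Illustrative only, not code from this PR: a rule whose mesh axis is None
# leaves that dimension replicated, while a `context` rule shards it.
import flax.linen as nn

rules = (
    ("self_attn_q_length", "context"),   # sharded across the context mesh axis
    ("cross_attn_kv_length", None),      # replicated: no mesh axis assigned
)

print(nn.logical_to_mesh_axes(("self_attn_q_length",), rules))
# PartitionSpec('context',)
print(nn.logical_to_mesh_axes(("cross_attn_kv_length",), rules))
# PartitionSpec(None,)
```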
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
@@ -64,9 +64,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 4096
dropout: 0.0

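Note (not part of the diff, and it applies to the identical `ulysses_shards` additions in the other configs below): the comment defines ring shards as `context / ulysses_shards`. A minimal sketch of that arithmetic, assuming `-1` means "pick a default"; the real semantics live in the attention code, not in this hunk.

```python
# Illustrative sketch, not code from this PR. Treating -1 as "auto" is an
# assumption inferred from it being the shipped default.
def split_context(context_shards: int, ulysses_shards: int) -> tuple[int, int]:
  """Split a flat context axis into (ulysses, ring) shard counts."""
  if ulysses_shards == -1:  # assumed: fall back to a kernel-chosen default
    ulysses_shards = context_shards
  if context_shards % ulysses_shards != 0:
    raise ValueError("ulysses_shards must divide the context axis size")
  return ulysses_shards, context_shards // ulysses_shards

# An 8-way context axis with ulysses_shards=2 gives a 2-way Ulysses
# all-to-all nested inside a 4-way ring pass.
assert split_context(8, 2) == (2, 4)
```
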
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -60,9 +60,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
20 changes: 11 additions & 9 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -64,9 +64,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 4096
dropout: 0.0

@@ -81,14 +83,14 @@ mask_padding_tokens: True
attention_sharding_uniform: True

flash_block_sizes: {
"block_q" : 512,
"block_kv_compute" : 512,
"block_kv" : 512,
"block_q_dkv" : 512,
"block_kv_dkv" : 512,
"block_kv_dkv_compute" : 512,
"block_q_dq" : 512,
"block_kv_dq" : 512,
"block_q" : 2048,
🟡 block_q and block_kv were increased from 512 to 2048 in this config. Is this change required for the new ulysses_ring attention kernel or is it a general performance optimization? It might be worth documenting why this specific model size received this update, or applying it to others if it's a general improvement.

"block_kv_compute" : 1024,
"block_kv" : 2048,
"block_q_dkv" : 2048,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 1024,
"block_q_dq" : 2048,
"block_kv_dq" : 2048,
"use_fused_bwd_kernel": False,
}
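
Note (not part of the diff): these keys line up one-to-one with the fields of the Pallas splash-attention `BlockSizes` dataclass on the non-tokamax path; the import below is the standard JAX one, but whether maxdiffusion constructs it exactly this way is an assumption.

```python
# Illustrative only: build BlockSizes from a dict shaped like the config above.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

cfg = {
    "block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048,
    "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 1024,
    "block_q_dq": 2048, "block_kv_dq": 2048, "use_fused_bwd_kernel": False,
}
block_sizes = splash_attention_kernel.BlockSizes(**cfg)
```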
# Use on v6e
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_animate.yml
@@ -62,9 +62,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 4096
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -64,9 +64,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 4096
dropout: 0.0

4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -64,9 +64,11 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this.
ulysses_shards: -1
flash_min_seq_length: 4096
dropout: 0.0

2 changes: 1 addition & 1 deletion src/maxdiffusion/max_utils.py
@@ -617,7 +617,7 @@ def get_flash_block_sizes(config):
"""Create custom flash attention BlockSizes."""
flash_block_sizes = None
if len(config.flash_block_sizes.keys()) > 0:
attention_is_tokamax = "tokamax" in config.attention
attention_is_tokamax = "tokamax" in config.attention or config.attention == "ulysses_ring"
user_block_sizes: Dict[str, int] = config.flash_block_sizes
if attention_is_tokamax:
max_logging.log(
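
Note (not part of the diff): the effect of the one-line change, spelled out below: `ulysses_ring` now takes the tokamax block-size handling alongside the `tokamax_*` backends.

```python
# Mirrors the changed predicate; no new API involved.
for attention in ("tokamax_flash", "tokamax_ring", "ulysses_ring", "flash"):
  attention_is_tokamax = "tokamax" in attention or attention == "ulysses_ring"
  print(attention, attention_is_tokamax)
# tokamax_flash True
# tokamax_ring True
# ulysses_ring True
# flash False
```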