diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index a92d5ec3..9cb75947 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -95,3 +95,15 @@ [CROSS_ATTN_Q_LENGTH, CONTEXT], [CROSS_ATTN_KV_LENGTH, CONTEXT], ] + +### Common axis rules for 2D Ulysses + ring attention ### +# Public configs shard sequence on `context`; attention code privately reshapes +# that axis into hidden ring and Ulysses axes for the hybrid kernel. +ULYSSES_RING_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, CONTEXT], + [SELF_ATTN_KV_LENGTH, CONTEXT], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, CONTEXT], + [CROSS_ATTN_KV_LENGTH, None], +] diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f432928a..60c40f31 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -64,9 +64,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 0e055265..cbd76e16 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -60,9 +60,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index bf29fa86..909fe00b 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -64,9 +64,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. 
+ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 @@ -81,14 +83,14 @@ mask_padding_tokens: True attention_sharding_uniform: True flash_block_sizes: { - "block_q" : 512, - "block_kv_compute" : 512, - "block_kv" : 512, - "block_q_dkv" : 512, - "block_kv_dkv" : 512, - "block_kv_dkv_compute" : 512, - "block_q_dq" : 512, - "block_kv_dq" : 512, + "block_q" : 2048, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 1024, + "block_q_dq" : 2048, + "block_kv_dq" : 2048, "use_fused_bwd_kernel": False, } # Use on v6e diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 8f95c855..4d98b267 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -62,9 +62,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index ca2d239a..333134be 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -64,9 +64,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 90799524..1b90b612 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -64,9 +64,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. 
+ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 3a885fba..14166bd8 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -617,7 +617,7 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: - attention_is_tokamax = "tokamax" in config.attention + attention_is_tokamax = "tokamax" in config.attention or config.attention == "ulysses_ring" user_block_sizes: Dict[str, int] = config.flash_block_sizes if attention_is_tokamax: max_logging.log( diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 6ff4b4fe..30cf2d2c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,6 +15,7 @@ import contextlib import functools import math +import numpy as np from typing import Optional, Callable, Tuple, Dict import flax.linen as nn from flax import nnx @@ -61,6 +62,9 @@ CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH +INTERNAL_RING_AXIS = "ring" +INTERNAL_ULYSSES_AXIS = "ulysses" + def _coerce_tokamax_block_sizes(block_sizes): # Tokamax requires fused bwd; convert if needed. @@ -159,6 +163,42 @@ def _unflatten_heads(tensor, heads): return tensor +def _replace_mesh_axis(axis_spec, old_axis: str, new_axes: tuple[str, ...]): + if axis_spec == old_axis: + return new_axes + if isinstance(axis_spec, tuple): + replacement = [] + for axis in axis_spec: + if axis == old_axis: + replacement.extend(new_axes) + else: + replacement.append(axis) + return tuple(replacement) + return axis_spec + + +def _replace_mesh_axis_names(axis_names, old_axis: str, new_axes: tuple[str, ...]): + return jax.sharding.PartitionSpec(*(_replace_mesh_axis(axis_name, old_axis, new_axes) for axis_name in axis_names)) + + +def _create_internal_ulysses_ring_mesh( + mesh: Mesh, + ring_shards: int, + ulysses_shards: int, + ring_axis: str = INTERNAL_RING_AXIS, + ulysses_axis: str = INTERNAL_ULYSSES_AXIS, +) -> Mesh: + """Split the public context mesh axis into private ring and Ulysses axes.""" + mesh_axis_names = tuple(mesh.axis_names) + context_axis_index = mesh_axis_names.index("context") + devices = mesh.devices + new_shape = devices.shape[:context_axis_index] + (ring_shards, ulysses_shards) + devices.shape[context_axis_index + 1 :] + new_axis_names = ( + mesh_axis_names[:context_axis_index] + (ring_axis, ulysses_axis) + mesh_axis_names[context_axis_index + 1 :] + ) + return Mesh(devices.reshape(new_shape), new_axis_names) + + def _reshape_data_for_flash(tensor, heads, num_context_shards=1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. @@ -533,13 +573,12 @@ def _ulysses_attention( Tensors arrive sequence-sharded on the context axis. Inside a shard_map the all-to-all collectives trade sequence shards for head shards, run local - splash attention on the full sequence with a subset of heads, then all-to-all - back. + splash attention on the full sequence with a subset of heads, then + all-to-all back. """ axis_name = "context" num_shards = mesh.shape[axis_name] - # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting. 
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards) key, _ = _reshape_data_for_flash(key, heads, num_shards) value, _ = _reshape_data_for_flash(value, heads, num_shards) @@ -551,7 +590,6 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) - if not use_custom_kernel: block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") @@ -566,8 +604,8 @@ def _ulysses_attention( check_vma=False, ) def wrap_ulysses_attention(query, key, value): - # Swap sharding modes: each device gives up a slice of sequence and gathers - # a slice of heads, so the local splash kernel sees the full sequence. + # Swap sharding: each device gives up a slice of heads and gathers + # a slice of sequence, so the local kernel sees the full sequence. query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) @@ -677,11 +715,157 @@ def wrap_ulysses_attention(query, key, value): attention_output = vmapped_splash(query, key, value, segment_ids) attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - # Restore the original layout expected by the rest of the model: - # head-sharded / full-sequence -> sequence-sharded / full-heads. + # Restore original layout: head-sharded/full-sequence -> sequence-sharded/full-heads. + attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + return attention_output + + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) + if not (query.shape[0] / devices_in_batch_sharding).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" + ) + x = wrap_ulysses_attention(query, key, value) + x = x[:, :, :orig_q_seq_len, :] + x = _reshape_heads_to_head_dim(x) + + return x + + +def _ulysses_ring_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + mask_padding_tokens: bool = True, + residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, + ulysses_axis: str = INTERNAL_ULYSSES_AXIS, + ring_axis: str = INTERNAL_RING_AXIS, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, +) -> jax.Array: + """2D context-parallel attention using a private Ulysses x ring mesh. + + Public configs only shard sequence on the context axis. Internally this + reshapes that same device axis into hidden ring and Ulysses axes, runs the + Ulysses all-to-all over the hidden Ulysses axis, and rotates K/V over the + hidden ring axis. 
+ """ + context_axis = "context" + if context_axis not in mesh.shape: + raise ValueError(f"Ulysses ring attention requires mesh axis {context_axis!r}, got mesh axes {mesh.shape}.") + + num_context_shards = mesh.shape[context_axis] + num_ulysses_shards = ulysses_shards + if num_ulysses_shards <= 0: + raise ValueError("Ulysses ring attention requires ulysses_shards to be set from config or command line.") + if num_context_shards % num_ulysses_shards != 0: + raise ValueError( + "Ulysses ring attention requires the requested Ulysses shard count to divide the context shard count, " + f"got context_shards={num_context_shards} and ulysses_shards={num_ulysses_shards}." + ) + if heads % num_ulysses_shards != 0: + raise ValueError( + "Ulysses ring attention requires the number of heads to be divisible by the requested Ulysses shard count, " + f"got heads={heads} and ulysses_shards={num_ulysses_shards}." + ) + num_ring_shards = num_context_shards // num_ulysses_shards + internal_mesh = _create_internal_ulysses_ring_mesh( + mesh, + ring_shards=num_ring_shards, + ulysses_shards=num_ulysses_shards, + ring_axis=ring_axis, + ulysses_axis=ulysses_axis, + ) + internal_sequence_axes = (ring_axis, ulysses_axis) + num_sequence_shards = num_context_shards + + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_sequence_shards) + key, orig_kv_seq_len = _reshape_data_for_flash(key, heads, num_sequence_shards) + value, _ = _reshape_data_for_flash(value, heads, num_sequence_shards) + + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring") + + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + internal_q_axis_names = _replace_mesh_axis_names(q_axis_names, context_axis, internal_sequence_axes) + internal_kv_axis_names = _replace_mesh_axis_names(kv_axis_names, context_axis, internal_sequence_axes) + + @functools.partial( + jax.shard_map, + mesh=internal_mesh, + in_specs=(internal_q_axis_names, internal_kv_axis_names, internal_kv_axis_names), + out_specs=internal_q_axis_names, + check_vma=False, + ) + def wrap_ulysses_ring_attention(query, key, value): + # Swap sharding: each device gives up a slice of heads and gathers + # a slice of sequence, so the local kernel sees the full sequence. 
+ query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + + uses_fused_kernel = block_sizes.use_fused_bwd_kernel + block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) + block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) + if uses_fused_kernel: + block_q_sizes += (block_sizes.block_q_dkv,) + block_kv_sizes += (block_sizes.block_kv_dkv,) + else: + block_q_sizes += (block_sizes.block_q_dq,) + block_kv_sizes += (block_sizes.block_kv_dq,) + + block_q = max(*block_q_sizes) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) + block_kv = max(*block_kv_sizes) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_kv) + + q_padded_len = query.shape[2] + kv_padded_len = key.shape[2] + total_kv_len = kv_padded_len * num_ring_shards + + if mask_padding_tokens: + q_valid = np.arange(q_padded_len) < query_seq_len + kv_indices = np.arange(total_kv_len) + kv_ring_offsets = kv_indices % kv_padded_len + kv_ring_indices = kv_indices // kv_padded_len + kv_global_indices = kv_ring_indices * key_seq_len + kv_ring_offsets + kv_valid = (kv_ring_offsets < key_seq_len) & (kv_global_indices < orig_kv_seq_len) + mask = tokamax_splash_attention_mask.NumpyMask(q_valid[:, None] & kv_valid[None, :]) + else: + mask = tokamax_splash_attention_mask.FullMask( + _shape=(q_padded_len, total_kv_len), + ) + + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( + mask=mask, + is_mqa=False, + config=convert_to_tokamax_splash_config( + block_sizes, + residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ), + save_residuals=False, + ring_axis=ring_axis, + kv_seq_shards=num_ring_shards, + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, None) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + # Restore original layout: head-sharded/full-sequence -> sequence-sharded/full-heads. 
attention_output = jax.lax.all_to_all( attention_output, - axis_name=axis_name, + axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True, @@ -694,7 +878,8 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_attention(query, key, value) + x = wrap_ulysses_ring_attention(query, key, value) + x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -865,6 +1050,27 @@ def ulysses_kernel(q, k, v, context): ) +@register_kernel("ulysses_ring") +def ulysses_ring_kernel(q, k, v, context): + return _ulysses_ring_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_base2_exp=context["use_base2_exp"], + use_experimental_scheduler=context["use_experimental_scheduler"], + ulysses_shards=context["ulysses_shards"], + ) + + @register_kernel("flash") def flash_kernel(q, k, v, context): return _tpu_flash_attention( @@ -953,6 +1159,7 @@ def _apply_attention( attention_mask: Array = None, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, ): """Routes to different attention kernels using a module-level registry.""" @@ -962,7 +1169,7 @@ def _apply_attention( seq_len_idx = 2 can_use_flash_attention = True - if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]: + if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom", "ulysses_ring"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length @@ -983,6 +1190,7 @@ def _apply_attention( "scale": scale, "use_base2_exp": use_base2_exp, "use_experimental_scheduler": use_experimental_scheduler, + "ulysses_shards": ulysses_shards, "dim_head": dim_head, "split_head_dim": split_head_dim, "float32_qk_product": float32_qk_product, @@ -1188,10 +1396,12 @@ def __init__( residual_checkpoint_name: str | None = None, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, ): self.dpa_layer = None self.use_base2_exp = use_base2_exp self.use_experimental_scheduler = use_experimental_scheduler + self.ulysses_shards = ulysses_shards if attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error @@ -1254,6 +1464,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask attention_mask=attention_mask, use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False, use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False, + ulysses_shards=(self.ulysses_shards if hasattr(self, "ulysses_shards") else -1), ) @@ -1274,6 +1485,7 @@ class AttentionOp(nn.Module): quant: Quant = None use_base2_exp: bool = False use_experimental_scheduler: bool = False + ulysses_shards: int = -1 def setup(self): self.dpa_layer = None @@ -1321,6 +1533,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask 
attention_mask=attention_mask, use_base2_exp=self.use_base2_exp, use_experimental_scheduler=self.use_experimental_scheduler, + ulysses_shards=self.ulysses_shards, ) @@ -1357,9 +1570,15 @@ def __init__( enable_jax_named_scopes: bool = False, added_kv_proj_dim: Optional[int] = None, # New for I2V image_seq_len: Optional[int] = None, # New for I2V - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, ): + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } + if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") self.dim_head = dim_head @@ -1379,8 +1598,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - if attention_kernel == "tokamax_ring" and not is_self_attention: - attention_kernel = "tokamax_flash" # do not use ring attention for cross attention + if attention_kernel in ("tokamax_ring", "ulysses_ring") and not is_self_attention: + attention_kernel = "tokamax_flash" self.added_kv_proj_dim = added_kv_proj_dim # New for I2V self.image_seq_len = image_seq_len # New for I2V tpu_type = get_tpu_type() @@ -1403,8 +1622,9 @@ def __init__( quant=quant, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + use_base2_exp=attention_config["use_base2_exp"], + use_experimental_scheduler=attention_config["use_experimental_scheduler"], + ulysses_shards=attention_config["ulysses_shards"], ) # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. @@ -1617,7 +1837,6 @@ def __call__( query_proj = _unflatten_heads(query_proj, self.heads) key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) - # output of _unflatten_heads Batch, heads, seq_len, head_dim query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) query_proj = checkpoint_name(query_proj, "query_proj") diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index f5057f50..40c6be3f 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -353,10 +353,15 @@ def __init__( dropout: float = 0.0, mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, ): self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } # 1. Self-attention self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) @@ -379,8 +384,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) # 1. 
Cross-attention @@ -405,8 +409,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -570,14 +573,19 @@ def __init__( mask_padding_tokens: bool = True, scan_layers: bool = True, enable_jax_named_scopes: bool = False, - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels self.num_layers = num_layers self.scan_layers = scan_layers self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -637,8 +645,7 @@ def init_block(rngs): enable_jax_named_scopes=enable_jax_named_scopes, added_kv_proj_dim=added_kv_proj_dim, image_seq_len=image_seq_len, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -667,6 +674,7 @@ def init_block(rngs): precision=precision, attention=attention, enable_jax_named_scopes=enable_jax_named_scopes, + attention_config=attention_config, ) blocks.append(block) self.blocks = nnx.data(blocks) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index fcb9151f..aa70e1e6 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -97,6 +97,10 @@ def __init__( self.enable_jax_named_scopes = enable_jax_named_scopes self.apply_input_projection = apply_input_projection self.apply_output_projection = apply_output_projection + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + } # 1. Input projection self.proj_in = nnx.data([None]) @@ -132,8 +136,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) # 3. Cross-attention @@ -156,8 +159,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) assert cross_attn_norm is True, "cross_attn_norm must be True" self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -342,6 +344,10 @@ def __init__( self.num_layers = num_layers self.scan_layers = scan_layers self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + } # 1. 
Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -401,8 +407,7 @@ def __init__( dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) blocks.append(block) self.blocks = blocks @@ -433,8 +438,8 @@ def __init__( enable_jax_named_scopes=enable_jax_named_scopes, apply_input_projection=vace_block_id == 0, apply_output_projection=True, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + use_base2_exp=attention_config["use_base2_exp"], + use_experimental_scheduler=attention_config["use_experimental_scheduler"], ) vace_blocks.append(vace_block) self.vace_blocks = vace_blocks diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa29..81172d3e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -138,8 +138,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = config.scan_layers wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes - wan_config["use_base2_exp"] = config.use_base2_exp - wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler + wan_config["attention_config"] = { + "use_base2_exp": config.use_base2_exp, + "use_experimental_scheduler": config.use_experimental_scheduler, + "ulysses_shards": getattr(config, "ulysses_shards", -1), + } # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9c6c9125..17edd5cc 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -38,6 +38,7 @@ RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES, ULYSSES_ATTENTION_AXIS_RULES, + ULYSSES_RING_ATTENTION_AXIS_RULES, ) _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO, LTX2_3} @@ -214,10 +215,11 @@ def user_init(raw_keys): raw_keys["vae_logical_axis_rules"] = _lists_to_tuples(raw_keys["vae_logical_axis_rules"]) # Verify qkv is sharded across sequence. attention = raw_keys["attention"] - uses_ring_attention = "ring" in attention - uses_ulysses_attention = "ulysses" in attention + uses_ulysses_ring_attention = attention == "ulysses_ring" + uses_ring_attention = "ring" in attention and not uses_ulysses_ring_attention + uses_ulysses_attention = "ulysses" in attention and not uses_ulysses_ring_attention uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"] - if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding: + if uses_ring_attention or uses_ulysses_attention or uses_ulysses_ring_attention or uses_uniform_sequence_sharding: max_logging.log( "Adding sequence sharding to q and kv if not already present because " f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set." 
@@ -233,7 +235,12 @@ def user_init(raw_keys): if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}") - if uses_ring_attention: + if uses_ulysses_ring_attention: + for ulysses_ring_attention_axis_rule in ULYSSES_RING_ATTENTION_AXIS_RULES: + if ulysses_ring_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ulysses ring attention axis rule {ulysses_ring_attention_axis_rule}") + new_rules.append(ulysses_ring_attention_axis_rule) + elif uses_ring_attention: for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 5c95dff8..ed4bf5cc 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -43,6 +43,10 @@ def _ulysses_mesh(self): devices = np.array(jax.devices()[:2]).reshape(1, 1, 2, 1) return Mesh(devices, ("data", "fsdp", "context", "tensor")) + def _ulysses_ring_mesh(self): + devices = np.array(jax.devices()[:4]).reshape(1, 1, 4, 1) + return Mesh(devices, ("data", "fsdp", "context", "tensor")) + def _ulysses_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -52,6 +56,15 @@ def _ulysses_axis_rules(self): (attention_flax.D_KV, None), ) + def _ulysses_ring_axis_rules(self): + return ( + (attention_flax.BATCH, "data"), + (attention_flax.SELF_ATTN_HEAD, None), + (attention_flax.SELF_ATTN_Q_LENGTH, "context"), + (attention_flax.SELF_ATTN_KV_LENGTH, "context"), + (attention_flax.D_KV, None), + ) + def _flash_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -441,6 +454,234 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, expected)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention layout test requires at least 4 devices.") + def test_ulysses_ring_attention_round_trips_query_when_heads_are_divisible(self): + """Hybrid Ulysses+ring should preserve layout while only exposing context.""" + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + def fake_make_ring_attention(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=4, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention mask test requires at least 4 
devices.") + def test_ulysses_ring_attention_masks_global_kv_padding(self): + """Hybrid Ulysses+ring masks padding at the end of the global sequence.""" + batch = 1 + length = 7 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + masks = [] + + def fake_make_ring_attention(**kwargs): + masks.append(kwargs["mask"].array.copy()) + + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + self.assertEqual(len(masks), 1) + np.testing.assert_array_equal(masks[0][:, :length], np.ones((4, length), dtype=bool)) + np.testing.assert_array_equal(masks[0][:, length:], np.zeros((4, 1), dtype=bool)) + + def test_ulysses_ring_attention_raises_when_heads_are_not_divisible_by_ulysses_shards(self): + """The hidden all-to-all head split still requires divisible heads.""" + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 3 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex( + ValueError, + r"heads=3 and ulysses_shards=2", + ): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ) + + def test_ulysses_ring_attention_raises_when_ulysses_shards_are_not_set(self): + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex(ValueError, r"ulysses_shards"): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, 
+ attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + def test_ulysses_ring_attention_raises_when_ulysses_shards_do_not_divide_context(self): + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex(ValueError, r"context_shards=4 and ulysses_shards=3"): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=3, + ) + if __name__ == "__main__": absltest.main()
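The `ulysses_shards` knob and `_create_internal_ulysses_ring_mesh` above split the public `context` mesh axis into hidden `ring` and `ulysses` axes. A standalone numpy sketch of that device-grid arithmetic (not part of the patch; the 8-device grid and shard counts below are made-up example values):

import numpy as np

axis_names = ("data", "fsdp", "context", "tensor")
devices = np.arange(8).reshape(1, 1, 8, 1)       # integers stand in for devices

ulysses_shards = 4                               # value of the new config key
context_shards = devices.shape[axis_names.index("context")]
assert context_shards % ulysses_shards == 0      # enforced by _ulysses_ring_attention
ring_shards = context_shards // ulysses_shards   # 8 // 4 = 2

i = axis_names.index("context")
internal_shape = devices.shape[:i] + (ring_shards, ulysses_shards) + devices.shape[i + 1 :]
internal_axis_names = axis_names[:i] + ("ring", "ulysses") + axis_names[i + 1 :]

print(internal_axis_names)                       # ('data', 'fsdp', 'ring', 'ulysses', 'tensor')
print(devices.reshape(internal_shape).shape)     # (1, 1, 2, 4, 1)

The same divisibility requirement applies to the head count: `_ulysses_ring_attention` raises if `ulysses_shards` does not divide both the context shard count and the number of heads.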
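`_replace_mesh_axis_names` rewrites the logical-to-mesh PartitionSpecs so that every use of the public `context` axis maps onto the hidden `("ring", "ulysses")` pair before entering `shard_map`. A minimal self-contained re-implementation of that rewrite (a sketch for illustration, not the patched helper itself):

from jax.sharding import PartitionSpec as P


def replace_axis(spec, old_axis, new_axes):
  """Swap one mesh axis name for several, preserving tuple entries."""
  def fix(entry):
    if entry == old_axis:
      return new_axes
    if isinstance(entry, tuple):
      flat = []
      for axis in entry:
        flat.extend(new_axes if axis == old_axis else (axis,))
      return tuple(flat)
    return entry

  return P(*(fix(entry) for entry in spec))


q_spec = P("data", None, "context", None)   # e.g. (BATCH, HEAD, Q_LENGTH, D_KV) after logical_to_mesh_axes
print(replace_axis(q_spec, "context", ("ring", "ulysses")))
# PartitionSpec('data', None, ('ring', 'ulysses'), None)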
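The comment inside `wrap_ulysses_ring_attention` ("each device gives up a slice of heads and gathers a slice of sequence") can be checked on a single host. The numpy simulation below illustrates what the tiled `all_to_all` over the hidden `ulysses` axis computes (it is not actual collective code, and the shapes are made up): after the exchange each device holds one head chunk over the full sequence.

import numpy as np

b, h, s, d = 2, 4, 8, 4
ulysses = 2                                           # hidden ulysses shard count
x = np.arange(b * h * s * d, dtype=np.float32).reshape(b, h, s, d)

# Before: device j holds every head for its slice of the sequence.
seq_shards = np.split(x, ulysses, axis=2)

# tiled all_to_all(split_axis=1, concat_axis=2): device i receives head chunk i
# from every device and concatenates the pieces along the sequence axis.
after = [
    np.concatenate(
        [np.split(seq_shards[j], ulysses, axis=1)[i] for j in range(ulysses)], axis=2
    )
    for i in range(ulysses)
]

# After: device i holds head chunk i over the *full* sequence.
for i in range(ulysses):
  np.testing.assert_array_equal(after[i], np.split(x, ulysses, axis=1)[i])

The closing `all_to_all` with `split_axis=2, concat_axis=1` is the inverse exchange, restoring the sequence-sharded, full-heads layout the rest of the model expects.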
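The padding mask built inside `_ulysses_ring_attention` marks key/value positions past the original sequence length as invalid across the whole ring, not just on the local shard. A worked example using the shapes from `test_ulysses_ring_attention_masks_global_kv_padding` (7 real tokens, 4 context shards = 2 ring x 2 ulysses, and, for simplicity, no extra block padding so `kv_padded_len == key_seq_len == 4`):

import numpy as np

orig_kv_seq_len = 7       # tokens before any padding
num_ring_shards = 2
key_seq_len = 4           # per-ring-shard length after the 7 -> 8 context pad
kv_padded_len = 4         # per-ring-shard length after block padding (none here)
q_padded_len, query_seq_len = 4, 4

total_kv_len = kv_padded_len * num_ring_shards                # 8
q_valid = np.arange(q_padded_len) < query_seq_len
kv_indices = np.arange(total_kv_len)
kv_ring_offsets = kv_indices % kv_padded_len
kv_ring_indices = kv_indices // kv_padded_len
kv_global_indices = kv_ring_indices * key_seq_len + kv_ring_offsets
kv_valid = (kv_ring_offsets < key_seq_len) & (kv_global_indices < orig_kv_seq_len)

mask = q_valid[:, None] & kv_valid[None, :]                   # shape (4, 8)
np.testing.assert_array_equal(mask[:, :7], True)              # real tokens stay attendable
np.testing.assert_array_equal(mask[:, 7:], False)             # the one global padding token is masked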
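The `pyconfig.user_init` change keeps the three attention families mutually exclusive, so only `ULYSSES_RING_ATTENTION_AXIS_RULES` is appended for the hybrid mode rather than the single-axis ring or Ulysses rule sets. A small sketch of that classification logic (illustrative only):

def classify(attention: str):
  uses_ulysses_ring = attention == "ulysses_ring"
  uses_ring = "ring" in attention and not uses_ulysses_ring
  uses_ulysses = "ulysses" in attention and not uses_ulysses_ring
  return uses_ring, uses_ulysses, uses_ulysses_ring


assert classify("ring") == (True, False, False)
assert classify("tokamax_ring") == (True, False, False)
assert classify("ulysses") == (False, True, False)
assert classify("ulysses_custom") == (False, True, False)
assert classify("ulysses_ring") == (False, False, True)
assert classify("flash") == (False, False, False)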