
Ring + Ulysses 2D context parallelism #404

Open
csgoogle wants to merge 1 commit into main from wan-ulysses-bshd-attention

Conversation

@csgoogle (Collaborator) commented May 12, 2026

Description
This PR adds support for a new ulysses_ring attention mode for WAN models. The implementation keeps the public sequence sharding on the existing context mesh axis, then internally reshapes that axis into private ring and ulysses axes so the attention path can combine Ulysses all-to-all head sharding with ring-based KV rotation.
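
For orientation, here is a minimal, self-contained sketch of the scheme described above, written only against public JAX collectives. The axis names (`ring`, `ulysses`), the shard counts `R`/`U`, and the plain-softmax ring loop are illustrative assumptions, not the PR's actual identifiers or kernel (the real path uses a flash/splash kernel with proper masking):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    R, U = 2, 2  # hypothetical ring and ulysses shard counts; needs R * U devices
    mesh = Mesh(np.array(jax.devices()[: R * U]).reshape(R, U), ("ring", "ulysses"))

    def _attn_shard(q, k, v):
        # Local blocks arrive as [batch, heads, seq / (R * U), head_dim]: the public
        # layout shards only the sequence dimension over the whole context axis.
        a2a = lambda t, sa, ca: jax.lax.all_to_all(
            t, "ulysses", split_axis=sa, concat_axis=ca, tiled=True)
        # Ulysses: trade head shards for sequence shards over the "ulysses" sub-axis,
        # leaving heads split U ways and the sequence split only R ways.
        q, k, v = a2a(q, 1, 2), a2a(k, 1, 2), a2a(v, 1, 2)
        scale = q.shape[-1] ** -0.5
        num = jnp.zeros_like(q)
        den = jnp.zeros(q.shape[:-1] + (1,), q.dtype)
        perm = [(i, (i + 1) % R) for i in range(R)]
        # Ring: rotate K/V around the "ring" sub-axis, accumulating an unnormalized
        # softmax (a real kernel would use an online-softmax / flash formulation).
        for _ in range(R):
            scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
            weights = jnp.exp(scores)
            num = num + jnp.einsum("bhqk,bhkd->bhqd", weights, v)
            den = den + jnp.sum(weights, axis=-1, keepdims=True)
            k = jax.lax.ppermute(k, "ring", perm)
            v = jax.lax.ppermute(v, "ring", perm)
        out = num / den
        # Inverse Ulysses all-to-all restores the public head/sequence layout.
        return a2a(out, 2, 1)

    attention_2d = shard_map(
        _attn_shard, mesh,
        in_specs=P(None, None, ("ring", "ulysses"), None),
        out_specs=P(None, None, ("ring", "ulysses"), None),
        check_rep=False,  # skip output replication checking to keep the sketch simple
    )

This only works when the head count divides by U and the sequence length by R * U, which is exactly the kind of precondition the ulysses_shards validation described below guards against.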

Changes

  • Adds ulysses_ring attention kernel registration and routing.
  • Introduces internal mesh reshaping from context into hidden ring and ulysses axes.
  • Adds ulysses_shards config plumbing through WAN pipeline, WAN transformer blocks, and attention ops.
  • Adds logical axis rules for Ulysses ring attention.
  • Updates WAN config files to document ulysses_ring support and add ulysses_shards.
  • Adds tests covering layout round-trip behavior and validation errors for invalid Ulysses shard settings.

Testing

  • Added unit tests in src/maxdiffusion/tests/attention_test.py for ulysses_ring behavior and validation.

@google-cla bot commented May 12, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

csgoogle force-pushed the wan-ulysses-bshd-attention branch from d6638b6 to 2582cf1 on May 15, 2026 at 11:36
csgoogle changed the title from Wan ulysses bshd attention to Ring + Ulysses 2D context parallelism on May 15, 2026
csgoogle force-pushed the wan-ulysses-bshd-attention branch 3 times, most recently from 48a900e to d10bf07 on May 15, 2026 at 19:27
@github-actions

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

csgoogle force-pushed the wan-ulysses-bshd-attention branch from d10bf07 to 6521956 on May 15, 2026 at 19:29

@github-actions bot left a comment


## 📋 Review Summary

This Pull Request introduces a hybrid Ulysses + Ring 2D context parallelism for attention, which is a significant addition for handling long sequences in large-scale diffusion models. The implementation correctly sets up the internal mesh and handles the necessary all_to_all communications. However, there are some concerns regarding the masking logic and the handling of the attention_mask parameter in the new kernel that should be addressed to ensure correctness across all use cases.

🔍 General Feedback

  • Positive Highlights: The refactoring of Attention classes to use a more flexible attention_config dictionary is a great improvement for maintainability and reduces the complexity of method signatures.
  • Robustness: The validation checks for ulysses_shards and mesh axes are well-implemented and will help users avoid misconfiguration.
  • Testing: New tests for the Ulysses-Ring layout are included, although more comprehensive tests covering the actual attention computation with padding would further improve confidence.

"block_kv_dkv_compute" : 512,
"block_q_dq" : 512,
"block_kv_dq" : 512,
"block_q" : 2048,


🟡 block_q and block_kv were increased from 512 to 2048 in this config. Is this change required for the new ulysses_ring attention kernel or is it a general performance optimization? It might be worth documenting why this specific model size received this update, or applying it to others if it's a general improvement.

f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
)
x = wrap_ulysses_attention(query, key, value)
x = x[:, :, :orig_q_seq_len, :]


🟠 The attention_mask parameter is defined but not used in the _ulysses_ring_attention implementation. Other attention kernels in this file use this mask to handle sequence padding or specific conditioning (like I2V). This feature should be implemented to ensure consistency and correctness for models that rely on attention masking.


block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring")

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)


🟠 The use of np.tile(kv_shard_valid, num_ring_shards) assumes that every shard in the ring has an identical valid/padding pattern. In sequence parallelism, padding is typically only at the end of the global sequence (affecting only the last shard). Tiling a local mask will incorrectly mask valid tokens in earlier shards or fail to mask padding tokens in the last shard correctly.

Consider calculating the mask based on global sequence positions:

    # Example logic for global masking
    ring_idx = jax.lax.axis_index(ring_axis)
    q_global_start = ring_idx * query_seq_len
    # ... logic to create a mask that correctly accounts for orig_q_seq_len across the ring
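
For concreteness, one hedged sketch of what that global-position masking could look like, reusing the hypothetical names above (`ring_axis`, `query_seq_len`, `orig_q_seq_len`). It has to run inside the shard_map/collective context, and it assumes padding only ever sits at the end of the global sequence; it is not the fix that was actually pushed:

    import jax
    import jax.numpy as jnp

    def global_valid_mask(ring_axis, query_seq_len, orig_q_seq_len):
        # Which of this shard's positions are real tokens, judged by their
        # *global* index rather than by tiling one local pattern around the ring.
        ring_idx = jax.lax.axis_index(ring_axis)   # this shard's slot in the ring
        positions = ring_idx * query_seq_len + jnp.arange(query_seq_len)
        return positions < orig_q_seq_len           # False marks global padding

For self-attention the same bound applies to each rotated K/V block, so a key-side mask can be built the same way from that block's ring index.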

@csgoogle (Collaborator, Author)


thanks, fixed it.

csgoogle force-pushed the wan-ulysses-bshd-attention branch from 6521956 to 5c4d053 on May 15, 2026 at 19:39
csgoogle marked this pull request as ready for review on May 15, 2026 at 19:57
csgoogle requested a review from entrpn as a code owner on May 15, 2026 at 19:57
@github-actions

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions bot left a comment


## 📋 Review Summary

This PR successfully implements the ulysses_ring attention mode, a sophisticated 2D parallelism strategy that combines Ulysses all-to-all head sharding with Ring KV rotation. The implementation is clean, well-integrated into the existing WAN model infrastructure, and includes a comprehensive suite of unit tests verifying both functional correctness and validation logic.

🔍 General Feedback

  • Elegant Mesh Reshaping: The internal reshaping of the context mesh axis into private ring and ulysses axes within shard_map is a great architectural choice, keeping the public configuration simple.
  • Robust Masking: The global sequence masking logic in the hybrid kernel correctly handles padding, which is critical for accuracy in sequence-parallel settings.
  • Thorough Testing: The addition of detailed round-trip and validation tests in attention_test.py is excellent and ensures the new feature is robust against edge cases.

[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],


🟡 This axis rule shards the cross-attention KV length as None (replicated), which matches SEQUENCE_PARALLEL_AXIS_RULES. Given that ulysses_ring falls back to tokamax_flash for cross-attention, this is consistent and ensures compatibility when the encoder sequence is not sharded across the context mesh.


key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)


🟢 It might be helpful to keep or restore the comment explaining the output shape of _unflatten_heads, as it clarifies the tensor layout for anyone reading this section.

Suggested change
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim

),
save_residuals=False,
ring_axis=ring_axis,
kv_seq_shards=num_ring_shards,


🟡 The use of np.arange and tokamax_splash_attention_mask.NumpyMask inside the shard_map function assumes that q_padded_len, total_kv_len, query_seq_len, and key_seq_len are static values at trace time. This is currently true for TPU workloads in this project, but if dynamic shapes are introduced in the future, this would need to transition to jnp and a dynamic mask implementation.

query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)



🟢 Excellent and thorough validation of ulysses_shards. Ensuring it divides the context_axis and that the attention head count is divisible by it prevents difficult-to-debug sharding errors later in execution.
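
As a rough illustration of the checks being praised here (function name, argument names, and messages are assumptions, not the PR's literal code), the divisibility constraints boil down to something like:

    def validate_ulysses_config(context_shards: int, ulysses_shards: int, num_heads: int) -> None:
        # ulysses_shards must evenly split the context mesh axis so the internal
        # (ring, ulysses) reshape is possible.
        if context_shards % ulysses_shards != 0:
            raise ValueError(
                f"ulysses_shards={ulysses_shards} must divide the context axis size {context_shards}."
            )
        # Heads are sharded across the ulysses axis by the all-to-all, so they must divide too.
        if num_heads % ulysses_shards != 0:
            raise ValueError(
                f"num_heads={num_heads} must be divisible by ulysses_shards={ulysses_shards}."
            )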

self.assertEqual(output.shape, query.shape)
self.assertTrue(jnp.array_equal(output, expected))

@unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention layout test requires at least 4 devices.")


🟢 These new tests are very well-structured, covering layout round-trips, global KV padding masking, and various validation error cases. This provides high confidence in the new attention kernel's correctness.
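
A hedged sketch of what a layout round-trip test of this kind might look like (the function and variable names are illustrative, not the ones in attention_test.py): the Ulysses head-to-sequence all-to-all followed by its inverse should reproduce the input bit-for-bit on a 2 x 2 (ring, ulysses) mesh.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    def test_ulysses_layout_round_trip():
        # Requires at least 4 devices, arranged as a 2 x 2 (ring, ulysses) mesh.
        mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("ring", "ulysses"))
        x = jnp.arange(2 * 8 * 16 * 4, dtype=jnp.float32).reshape(2, 8, 16, 4)  # [b, h, s, d]

        def round_trip(t):
            # Forward: split heads across "ulysses", gather sequence; then invert.
            t = jax.lax.all_to_all(t, "ulysses", split_axis=1, concat_axis=2, tiled=True)
            return jax.lax.all_to_all(t, "ulysses", split_axis=2, concat_axis=1, tiled=True)

        f = shard_map(
            round_trip, mesh,
            in_specs=P(None, None, ("ring", "ulysses"), None),
            out_specs=P(None, None, ("ring", "ulysses"), None),
            check_rep=False,
        )
        np.testing.assert_array_equal(np.asarray(f(x)), np.asarray(x))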
