Ring + Ulysses 2D context parallelism #404
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces a hybrid Ulysses + Ring 2D context parallelism for attention, which is a significant addition for handling long sequences in large-scale diffusion models. The implementation correctly sets up the internal mesh and handles the necessary all_to_all communications. However, there are some concerns regarding the masking logic and the handling of the attention_mask parameter in the new kernel that should be addressed to ensure correctness across all use cases.
🔍 General Feedback
- Positive Highlights: The refactoring of `Attention` classes to use a more flexible `attention_config` dictionary is a great improvement for maintainability and reduces the complexity of method signatures.
- Robustness: The validation checks for `ulysses_shards` and mesh axes are well-implemented and will help users avoid misconfiguration.
- Testing: New tests for the Ulysses-Ring layout are included, although more comprehensive tests covering the actual attention computation with padding would further improve confidence.
| "block_kv_dkv_compute" : 512, | ||
| "block_q_dq" : 512, | ||
| "block_kv_dq" : 512, | ||
| "block_q" : 2048, |
🟡 block_q and block_kv were increased from 512 to 2048 in this config. Is this change required for the new ulysses_ring attention kernel or is it a general performance optimization? It might be worth documenting why this specific model size received this update, or applying it to others if it's a general improvement.
| f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" | ||
| ) | ||
| x = wrap_ulysses_attention(query, key, value) | ||
| x = x[:, :, :orig_q_seq_len, :] |
🟠 The attention_mask parameter is defined but not used in the _ulysses_ring_attention implementation. Other attention kernels in this file use this mask to handle sequence padding or specific conditioning (like I2V). This feature should be implemented to ensure consistency and correctness for models that rely on attention masking.
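For illustration only, a common way such a mask is folded into attention is as an additive bias over padded key positions; the function name and shapes below are hypothetical, not this kernel's code:

import jax.numpy as jnp

def apply_padding_bias(logits, attention_mask):
  # logits: (batch, heads, q_len, kv_len) pre-softmax attention scores.
  # attention_mask: boolean (batch, kv_len); True for valid tokens, False for padding.
  bias = jnp.where(attention_mask[:, None, None, :], 0.0, jnp.finfo(logits.dtype).min)
  return logits + bias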
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring")
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
🟠 The use of np.tile(kv_shard_valid, num_ring_shards) assumes that every shard in the ring has an identical valid/padding pattern. In sequence parallelism, padding is typically only at the end of the global sequence (affecting only the last shard). Tiling a local mask will incorrectly mask valid tokens in earlier shards or fail to mask padding tokens in the last shard correctly.
Consider calculating the mask based on global sequence positions:
# Example logic for global masking
ring_idx = jax.lax.axis_index(ring_axis)
q_global_start = ring_idx * query_seq_len
# ... logic to create a mask that correctly accounts for orig_q_seq_len across the ring
thanks, fixed it.
🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR successfully implements the ulysses_ring attention mode, a sophisticated 2D parallelism strategy that combines Ulysses all-to-all head sharding with Ring KV rotation. The implementation is clean, well-integrated into the existing WAN model infrastructure, and includes a comprehensive suite of unit tests verifying both functional correctness and validation logic.
🔍 General Feedback
- Elegant Mesh Reshaping: The internal reshaping of the `context` mesh axis into private `ring` and `ulysses` axes within `shard_map` is a great architectural choice, keeping the public configuration simple (see the sketch after this list).
- Robust Masking: The global sequence masking logic in the hybrid kernel correctly handles padding, which is critical for accuracy in sequence-parallel settings.
- Thorough Testing: The addition of detailed round-trip and validation tests in `attention_test.py` is excellent and ensures the new feature is robust against edge cases.
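A minimal sketch of what this internal axis split could look like; the names (`ring`, `ulysses`, `ulysses_shards`) are illustrative assumptions rather than the PR's exact code:

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical: the public "context" axis holds all devices used for sequence parallelism.
context_devices = np.array(jax.devices())
ulysses_shards = 2  # illustrative; the real value comes from config
ring_shards = context_devices.size // ulysses_shards  # assumes device count divisible by ulysses_shards

# Internally split the flat "context" axis into private ("ring", "ulysses") axes,
# so callers only ever configure "context" while the attention kernel sees a 2D mesh.
inner_mesh = Mesh(context_devices.reshape(ring_shards, ulysses_shards), ("ring", "ulysses"))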
[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],
🟡 This axis rule shards the cross-attention KV length as None (replicated), which matches SEQUENCE_PARALLEL_AXIS_RULES. Given that ulysses_ring falls back to tokamax_flash for cross-attention, this is consistent and ensures compatibility when the encoder sequence is not sharded across the context mesh.
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
🟢 It might be helpful to keep or restore the comment explaining the output shape of _unflatten_heads, as it clarifies the tensor layout for anyone reading this section.
Suggested change:
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim
),
save_residuals=False,
ring_axis=ring_axis,
kv_seq_shards=num_ring_shards,
🟡 The use of np.arange and tokamax_splash_attention_mask.NumpyMask inside the shard_map function assumes that q_padded_len, total_kv_len, query_seq_len, and key_seq_len are static values at trace time. This is currently true for TPU workloads in this project, but if dynamic shapes are introduced in the future, this would need to transition to jnp and a dynamic mask implementation.
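As a rough illustration of why static shapes make this work, a padding mask can be precomputed as a plain NumPy array at trace time; the names and sizes below are hypothetical, not the PR's code:

import numpy as np

# Hypothetical static values known at trace time.
q_padded_len = 2048       # padded per-shard query length
total_kv_len = 8192       # global (ring-wide) key/value length
orig_kv_valid_len = 7000  # valid tokens before padding in the global KV sequence

# Because these are Python ints during tracing, the mask is ordinary NumPy,
# which is what NumpyMask-style static masks expect.
kv_valid = np.arange(total_kv_len) < orig_kv_valid_len
mask = np.broadcast_to(kv_valid[None, :], (q_padded_len, total_kv_len))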
query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True)
🟢 Excellent and thorough validation of ulysses_shards. Ensuring it divides the context_axis and that heads are divisible prevents difficult-to-debug sharding errors later in the execution.
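A minimal sketch of the kind of divisibility checks being praised here, with made-up names rather than the PR's actual helpers:

def validate_ulysses_config(ulysses_shards: int, context_axis_size: int, num_heads: int) -> None:
  # Hypothetical helper: the real checks live in the attention setup code.
  if context_axis_size % ulysses_shards != 0:
    raise ValueError(
        f"ulysses_shards ({ulysses_shards}) must evenly divide the context axis size ({context_axis_size})."
    )
  if num_heads % ulysses_shards != 0:
    raise ValueError(
        f"Attention heads ({num_heads}) must be divisible by ulysses_shards ({ulysses_shards})."
    )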
self.assertEqual(output.shape, query.shape)
self.assertTrue(jnp.array_equal(output, expected))
@unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention layout test requires at least 4 devices.")
🟢 These new tests are very well-structured, covering layout round-trips, global KV padding masking, and various validation error cases. This provides high confidence in the new attention kernel's correctness.
Description
This PR adds support for a new `ulysses_ring` attention mode for WAN models. The implementation keeps the public sequence sharding on the existing `context` mesh axis, then internally reshapes that axis into private `ring` and `ulysses` axes so the attention path can combine Ulysses all-to-all head sharding with ring-based KV rotation.
Changes
- `ulysses_ring` attention kernel registration and routing.
- Reshape of `context` into hidden `ring` and `ulysses` axes.
- `ulysses_shards` config plumbing through the WAN pipeline, WAN transformer blocks, and attention ops.
- `ulysses_ring` support and add `ulysses_shards`.
Testing
- New tests in `src/maxdiffusion/tests/attention_test.py` for `ulysses_ring` behavior and validation.
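For readers unfamiliar with the hybrid scheme, a rough, self-contained sketch of the Ulysses + Ring data movement follows; the axis names, the `ulysses_shards` value, and the plain local attention standing in for the ring kernel are illustrative assumptions, not the PR's implementation:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative setup: assumes the device count and head count are divisible by ulysses_shards.
ulysses_shards = 2
devices = np.array(jax.devices())
ring_shards = devices.size // ulysses_shards
mesh = Mesh(devices.reshape(ring_shards, ulysses_shards), ("ring", "ulysses"))

def hybrid_attention(q, k, v):
  # Ulysses step: all-to-all shards heads over "ulysses" and gathers the sequence it held.
  q = jax.lax.all_to_all(q, axis_name="ulysses", split_axis=1, concat_axis=2, tiled=True)
  k = jax.lax.all_to_all(k, axis_name="ulysses", split_axis=1, concat_axis=2, tiled=True)
  v = jax.lax.all_to_all(v, axis_name="ulysses", split_axis=1, concat_axis=2, tiled=True)
  # Ring step (stand-in): the real kernel rotates KV blocks around the "ring" axis;
  # plain local attention is used here only to keep the sketch runnable.
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
  # Reverse all-to-all restores the original head/sequence layout.
  return jax.lax.all_to_all(out, axis_name="ulysses", split_axis=2, concat_axis=1, tiled=True)

# Sequence stays publicly sharded over ("ring", "ulysses"), i.e. the original "context" axis.
spec = P(None, None, ("ring", "ulysses"), None)  # (batch, heads, seq, head_dim)
attn = shard_map(hybrid_attention, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)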