-
Notifications
You must be signed in to change notification settings - Fork 73
Ring + Ulysses 2D context parallelism #404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,9 +64,11 @@ jit_initializers: True | |
| # Set true to load weights from pytorch | ||
| from_pt: True | ||
| split_head_dim: True | ||
| attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom | ||
| attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring | ||
| use_base2_exp: True | ||
| use_experimental_scheduler: True | ||
| # For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. | ||
| ulysses_shards: -1 | ||
| flash_min_seq_length: 4096 | ||
| dropout: 0.0 | ||
|
|
||
|
|
@@ -81,14 +83,14 @@ mask_padding_tokens: True | |
| attention_sharding_uniform: True | ||
|
|
||
| flash_block_sizes: { | ||
| "block_q" : 512, | ||
| "block_kv_compute" : 512, | ||
| "block_kv" : 512, | ||
| "block_q_dkv" : 512, | ||
| "block_kv_dkv" : 512, | ||
| "block_kv_dkv_compute" : 512, | ||
| "block_q_dq" : 512, | ||
| "block_kv_dq" : 512, | ||
| "block_q" : 2048, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 |
||
| "block_kv_compute" : 1024, | ||
| "block_kv" : 2048, | ||
| "block_q_dkv" : 2048, | ||
| "block_kv_dkv" : 2048, | ||
| "block_kv_dkv_compute" : 1024, | ||
| "block_q_dq" : 2048, | ||
| "block_kv_dq" : 2048, | ||
| "use_fused_bwd_kernel": False, | ||
| } | ||
| # Use on v6e | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 This axis rule shards the cross-attention KV length as
None(replicated), which matchesSEQUENCE_PARALLEL_AXIS_RULES. Given thatulysses_ringfalls back totokamax_flashfor cross-attention, this is consistent and ensures compatibility when the encoder sequence is not sharded across the context mesh.