[SPMD] Preserve parameter sharding with output data sharding #4721

Merged

Conversation
JackCaoG reviewed Mar 4, 2023
JackCaoG reviewed Mar 7, 2023
Contributor (Author)

Yes, we need at least 2 devices to create an HLO sharding. Added the safeguard.
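For illustration only, a minimal Python-level sketch of that kind of safeguard. The actual guard in this PR lives in the C++ sharding utilities and is not shown here; `global_runtime_device_count` is the helper from today's torch_xla runtime API, and `spmd_sharding_available` is a hypothetical name:

```python
import torch_xla.runtime as xr

def spmd_sharding_available() -> bool:
    # An HLO sharding only makes sense with at least two devices to
    # partition across; on a single device the annotation would be
    # degenerate, so callers should skip sharding in that case.
    return xr.global_runtime_device_count() >= 2
```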
JackCaoG reviewed Mar 7, 2023
jonb377 approved these changes Mar 8, 2023
```cpp
  // Assignment is not supported for this data type; fail loudly if called.
  void Assign(const Data& data) override {
    XLA_ERROR() << __FUNCTION__ << " not supported.";
  }
```
Collaborator

Nice! We can retry the simple MpDeviceLoader hack for SPMD once this lands; this was the blocker.
mateuszlewko pushed a commit that referenced this pull request on Mar 15, 2023:
[SPMD] Persist tensor sharding with XLA sharding propagation
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request on Mar 29, 2023:
[SPMD] Persist tensor sharding with XLA sharding propagation (…#4721)
This addresses the same problem as #4696 with an alternative solution: we shard the replicated output while handling the computation results. This avoids a post-traversal pass that replaces the original data node with a sharded one, and is thus more efficient. Key changes include:
- Introduce `ShardingUtil::OutputHandler`, with an `XLAShardingTest.OutputHandler` test for unit testing; `test_optimizer_step_with_sharding` already checks the validity of the change with a simple e2e example.
- Add `std::optional<xla::Shape>` to `ShardingSpec`.
- Add `std::optional<xla::OpSharding>` to `PjRtShardedData`.
- Add a `std::vector<XLATensor::ShardingSpecPtr>` param to `XLAGraphExecutor::ScheduleSyncTensorsGraph`, since the async function now calls `ShardingUtil::OutputHandler`.
- Add `XLAGraphExecutor::CollectShardingSpecs` before calling `ScheduleSyncTensorsGraph`.
- Add `WrapDataShards` and `GetDataSharding` APIs in `ComputationClient`.
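To make the effect concrete, here is a hedged end-to-end sketch in the spirit of the `test_optimizer_step_with_sharding` test mentioned above. The Python SPMD API names (`Mesh`, `mark_sharding`, `use_spmd`, `_get_xla_sharding_spec`) are taken from today's torch_xla and are assumptions rather than part of this PR's diff; the point is that after this change a parameter's sharding annotation survives an in-place optimizer update:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode (current torch_xla API)
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()

# 1D mesh over all devices; shard the weight's output dim across it.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,))
model = nn.Linear(128, 64).to(device)
xs.mark_sharding(model.weight, mesh, (0, None))

optimizer = optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(16, 128, device=device)).sum()
loss.backward()
optimizer.step()  # produces new data for model.weight in place
xm.mark_step()    # sync: runs the graph and materializes outputs

# Before this PR the updated weight came back replicated; with the
# output handler re-applying the collected sharding specs, the
# annotation below should still describe a sharded tensor.
print(torch_xla._XLAC._get_xla_sharding_spec(model.weight))
```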