
[SPMD] Mesh to support custom device order. #4162

Merged
yeounoh merged 1 commit into master from spmd_device_mesh on Nov 8, 2022

Conversation

Contributor

@yeounoh yeounoh commented Nov 7, 2022

This implements the Mesh class from #3871, to support custom device order in the logical XLA device mesh topology.
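As a rough sketch of the idea (mine, not the PR's exact implementation; class and method names are illustrative), a mesh along these lines stores a raveled list of device IDs and reshapes it into the logical topology, so the caller controls which physical device lands at each logical coordinate:

```python
import numpy as np

class Mesh:
    """Illustrative sketch: a logical device mesh with a custom device order."""

    def __init__(self, device_ids, mesh_shape, axis_names=None):
        self.device_ids = np.asarray(device_ids)
        self.mesh_shape = mesh_shape
        self.axis_names = axis_names
        assert self.device_ids.size == np.prod(mesh_shape), \
            "device_ids must cover the full mesh"

    def get_logical_mesh(self):
        # Fill the mesh in C-like (row-major) index order.
        return self.device_ids.reshape(self.mesh_shape)

# Custom order: device 2 sits next to device 0 on the first row.
mesh = Mesh([0, 2, 1, 3], (2, 2))
print(mesh.get_logical_mesh())  # [[0 2] [1 3]]
```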

@yeounoh yeounoh added the distributed (SPMD and other distributed things) label on Nov 7, 2022
@yeounoh yeounoh self-assigned this Nov 7, 2022
Comment thread test/test_xla_sharding.py
```python
), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported."
)
@unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"),
                 f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
```
Collaborator
@will-cromar I think PJRT-GPU single core is ready now?

Contributor Author

It's blocked on our SPMD side; once we support TPU, the transition to GPU should be easier. Maybe sometime next year, once we are done with the basic/core SPMD features?

Comment thread torch_xla/csrc/tensor_util.cpp
Comment thread torch_xla/experimental/xla_sharding.py Outdated

```python
Args:
    device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order.
        The list is reshaped to a `mesh_shape` array, filling the elements using C-like
        index order. For example,
```
Collaborator

where is the example lol?

Collaborator

oh ok it is below, you might want to change the wording here.

Contributor Author

Done
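For reference, the "C-like index order" in the docstring above is NumPy's default `reshape` behavior. The example below (mine, not taken from the PR) contrasts it with Fortran order, so it is clear which layout the raveled device list produces:

```python
import numpy as np

device_ids = [0, 1, 2, 3, 4, 5]

# C-like (row-major) order, NumPy's default: rows fill first.
c_mesh = np.array(device_ids).reshape((2, 3), order="C")
print(c_mesh)  # [[0 1 2] [3 4 5]]

# Fortran (column-major) order, for contrast: columns fill first.
f_mesh = np.array(device_ids).reshape((2, 3), order="F")
print(f_mesh)  # [[0 2 4] [1 3 5]]
```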

Collaborator

@JackCaoG JackCaoG left a comment

Thanks!

Collaborator

@jonb377 jonb377 left a comment

LGTM, thanks!

Comment thread torch_xla/experimental/xla_sharding.py Outdated
Comment on lines 78 to 80
```python
mesh_shape (Tuple[Union[int, None]]): An int tuple describing the logical topology
    of the device mesh, and each element describes the number of devices in
    the corresponding axis.
```
Collaborator

Looks like mesh_shape can be removed here

Contributor Author

Good catch :)

Comment thread test/test_xla_sharding.py

```python
def test_custom_tile_assignment(self):
    xt = torch.randn(10, 20).to(device=xm.xla_device())
    mesh_shape = (1, self.n_devices)
```
Collaborator

I see the tests have all devices mapped to a single axis - is there anything stopping us from using e.g. mesh_shape = (2, self.n_devices / 2)?

Contributor Author

No, but for unit testing a flat mesh is easier to work with, since we don't know how many devices we will have (e.g., for CPU, we will have 1).
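To illustrate the trade-off being discussed (my example, assuming 4 devices): a flat mesh works for any device count, while a 2D mesh requires the count to factor evenly across both axes:

```python
import numpy as np

n_devices = 4  # assumed count, for illustration
device_ids = np.arange(n_devices)

# Flat mesh: valid for any n_devices, including 1 (the CPU case).
flat = device_ids.reshape((1, n_devices))

# 2D mesh: only valid when n_devices is divisible by 2.
two_d = device_ids.reshape((2, n_devices // 2))

print(flat)   # [[0 1 2 3]]
print(two_d)  # [[0 1] [2 3]]
```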

```python
def __init__(self,
             device_ids: Union[np.ndarray, List],
             mesh_shape: Tuple[int, ...],
             axis_names: Tuple[str, ...] = None):
```
Collaborator

Just curious - how will axis_names be used long-term? Is it just for annotating the mesh?

Contributor Author

Good question. Mesh axis annotation is useful since it makes the annotation logic more readable. We can also build a partitioning rule based on the axis name, instead of int indices.
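A hedged sketch of what name-based axis lookup could look like (the `axis_size` helper and the `"data"`/`"model"` names below are hypothetical, not part of this PR):

```python
import numpy as np

class Mesh:
    """Illustrative sketch: resolve mesh axes by name instead of int index."""

    def __init__(self, device_ids, mesh_shape, axis_names=None):
        self.device_ids = np.asarray(device_ids).reshape(mesh_shape)
        # Fall back to positional int "names" when none are given.
        self.axis_names = axis_names or tuple(range(len(mesh_shape)))

    def axis_size(self, name):
        # Look up the axis index by its name, then return that axis's size.
        return self.device_ids.shape[self.axis_names.index(name)]

mesh = Mesh(range(8), (2, 4), axis_names=("data", "model"))
print(mesh.axis_size("data"), mesh.axis_size("model"))  # 2 4
```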

Contributor Author

yeounoh commented Nov 8, 2022

narrow_copy_dense has been renamed to narrow_copy_dense_symint upstream; rebasing to fix the build issue.

@yeounoh yeounoh force-pushed the spmd_device_mesh branch 3 times, most recently from 3d6bc93 to eea8e9c on November 8, 2022 at 19:33
Collaborator

@steventk-g steventk-g left a comment

lgtm!

@yeounoh yeounoh merged commit b096c5c into master Nov 8, 2022

Labels

distributed SPMD and other distributed things.


4 participants