[SPMD] Mesh to support custom device order #4162
Conversation
    ), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported."
    )
    @unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"),
                     f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
@will-cromar I think PJRT-GPU single core is ready now?
It's blocked on our SPMD side. Once we support TPU, the transition to GPU should be easier -- maybe sometime next year, once we are done with the basic/core SPMD features?
    Args:
        device_ids (Union[np.ndarray, List]): A raveled list of device IDs in a custom order. The list is reshaped
            to a `mesh_shape` array, filling the elements using C-like index order. For example,
where is the example lol?
oh ok it is below, you might want to change the wording here.
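The C-like (row-major) reshape semantics being discussed can be illustrated with plain NumPy; the device IDs below are hypothetical, not taken from the PR:

```python
import numpy as np

# A raveled list of 8 hypothetical device IDs in a custom order.
device_ids = [0, 4, 1, 5, 2, 6, 3, 7]
mesh_shape = (2, 4)

# C-like index order: the last axis index varies fastest, so the
# raveled list fills the mesh row by row.
mesh = np.array(device_ids).reshape(mesh_shape)
print(mesh)
# [[0 4 1 5]
#  [2 6 3 7]]
```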
Force-pushed from 970536f to 716865c
    mesh_shape (Tuple[Union[int, None]]): An int tuple describing the logical topology
        of the device mesh; each element describes the number of devices in
        the corresponding axis.
Looks like mesh_shape can be removed here
    def test_custom_tile_assignment(self):
        xt = torch.randn(10, 20).to(device=xm.xla_device())
        mesh_shape = (1, self.n_devices)
I see the tests have all devices mapped to a single axis - is there anything stopping us from using e.g. mesh_shape = (2, self.n_devices / 2)?
No, but for unit testing a flat mesh is easier to work with, since we don't know how many devices we would have (e.g., for CPU we will have 1).
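The constraint behind this exchange is that the mesh shape must cover exactly the available device count, which is why a flat `(1, n_devices)` mesh works for any count while a 2D split needs a divisible count. A quick sketch (the helper name `valid_mesh` is hypothetical, not part of the PR):

```python
import numpy as np

def valid_mesh(n_devices, mesh_shape):
    # A mesh shape is usable only if its element count equals n_devices.
    return int(np.prod(mesh_shape)) == n_devices

# A flat mesh works for any device count, including a single CPU device.
assert valid_mesh(1, (1, 1))
assert valid_mesh(8, (1, 8))

# A 2D split like (2, n // 2) requires an even device count.
assert valid_mesh(8, (2, 8 // 2))
assert not valid_mesh(8, (2, 3))
```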
    def __init__(self,
                 device_ids: Union[np.ndarray, List],
                 mesh_shape: Tuple[int, ...],
                 axis_names: Tuple[str, ...] = None):
Just curious - how will axis_names be used long-term? Is it just for annotating the mesh?
Good question -- mesh axis annotation is useful since it makes the annotation logic more readable. We can also build partitioning rules based on axis names instead of int indices.
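One way the name-to-axis lookup could work is sketched below with a toy stand-in class; `MeshSketch` and `axis_index` are hypothetical names for illustration, not the actual PyTorch/XLA API:

```python
from typing import List, Optional, Tuple, Union

import numpy as np


class MeshSketch:
    """Toy stand-in for the Mesh class under review (not the real API)."""

    def __init__(self,
                 device_ids: Union[np.ndarray, List],
                 mesh_shape: Tuple[int, ...],
                 axis_names: Optional[Tuple[str, ...]] = None):
        # Reshape the raveled device IDs into the logical mesh topology.
        self.mesh = np.array(device_ids).reshape(mesh_shape)
        self.axis_names = axis_names

    def axis_index(self, name: str) -> int:
        # Resolve a readable axis name to its integer mesh axis,
        # so sharding specs can say 'model' instead of 1.
        return self.axis_names.index(name)


mesh = MeshSketch(range(8), (2, 4), axis_names=('data', 'model'))
assert mesh.axis_index('model') == 1
assert mesh.mesh.shape == (2, 4)
```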
Force-pushed from 716865c to 15a30e4
Force-pushed from 3d6bc93 to eea8e9c
Force-pushed from eea8e9c to 234871f
This implements the Mesh class from #3871, to support custom device order in the logical XLA device mesh topology.