55import math
66import argparse
77from benchmarks .communication .constants import *
8+ from deepspeed .accelerator import get_accelerator
89
910global dist
1011
@@ -14,7 +15,7 @@ def init_torch_distributed(backend):
1415 import torch .distributed as dist
1516 torch .distributed .init_process_group (backend )
1617 local_rank = int (os .environ ['LOCAL_RANK' ])
17- torch . cuda .set_device (local_rank )
18+ get_accelerator () .set_device (local_rank )
1819
1920
2021def init_deepspeed_comm (backend ):
@@ -23,7 +24,7 @@ def init_deepspeed_comm(backend):
2324 import deepspeed .comm as dist
2425 deepspeed .init_distributed (dist_backend = backend )
2526 local_rank = int (os .environ ['LOCAL_RANK' ])
26- torch . cuda .set_device (local_rank )
27+ get_accelerator () .set_device (local_rank )
2728
2829
2930def init_processes (local_rank , args ):
@@ -101,14 +102,13 @@ def get_metric_strings(args, tput, busbw, duration):
101102
102103
103104def sync_all ():
104- torch . cuda .synchronize ()
105+ get_accelerator () .synchronize ()
105106 dist .barrier ()
106107
107108
108109def max_numel (comm_op , dtype , mem_factor , local_rank , args ):
109110 dtype_size = _element_size (dtype )
110- max_memory_per_gpu = torch .cuda .get_device_properties (
111- local_rank ).total_memory * mem_factor
111+ max_memory_per_gpu = get_accelerator ().total_memory (local_rank ) * mem_factor
112112 if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast' :
113113 elements_per_gpu = int (max_memory_per_gpu // dtype_size )
114114 elif comm_op == 'all_gather' :
@@ -185,7 +185,8 @@ def benchmark_parser():
185185 parser .add_argument ("--backend" ,
186186 type = str ,
187187 default = DEFAULT_BACKEND ,
188- choices = ['nccl' ],
188+ choices = ['nccl' ,
189+ 'ccl' ],
189190 help = 'Communication library to use' )
190191 parser .add_argument ("--dist" ,
191192 type = str ,
0 commit comments