Skip to content

Commit b04fedd

Browse files
authored
Enable cpu/xpu support for the benchmarking suite (deepspeedai#905)
* enable cpu/xpu support for the benchmarking suite * fixes according to review feedback
1 parent bbab278 commit b04fedd

8 files changed

Lines changed: 76 additions & 11 deletions

File tree

benchmarks/communication/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# The DeepSpeed Communication Benchmarking Suite
22

3-
The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:
3+
The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/), [NCCL Tests](https://github.com/NVIDIA/nccl-tests), and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can:
44
- Easily debug which layer of the communication software stack hangs or performance degradations originate from.
55
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed
66

@@ -77,6 +77,14 @@ Finally, users can choose specific communication operations to run in `run_all.p
7777
deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast
7878
</pre>
7979

80+
## CPU Support
81+
These benchmarks also support other devices, such as Intel CPUs via oneCCL.
82+
To run on an Intel CPU, append the argument `--device cpu` to any of the Python scripts.
83+
For example, to run with a single large message size on an Intel CPU:
84+
<pre>
85+
deepspeed all_reduce.py --device cpu
86+
</pre>
87+
8088

8189
# Adding Communication Benchmarks
8290

benchmarks/communication/all_gather.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
# Run all_gather and print metrics
1919
def timed_all_gather(input, output, start_event, end_event, args):
20+
if args.device == "cpu":
21+
print_rank_0(f"No Event support on CPU to measure time for now")
22+
return
2023
if args.dist == 'torch':
2124
import torch.distributed as dist
2225

@@ -64,8 +67,15 @@ def run_all_gather(local_rank, args):
6467
global_rank = dist.get_rank()
6568
world_size = dist.get_world_size()
6669

67-
start_event = torch.cuda.Event(enable_timing=True)
68-
end_event = torch.cuda.Event(enable_timing=True)
70+
if args.device == "xpu":
71+
start_event = torch.xpu.Event(enable_timing=True)
72+
end_event = torch.xpu.Event(enable_timing=True)
73+
elif args.device == "cpu":
74+
start_event = torch.cpu.Event()
75+
end_event = torch.cpu.Event()
76+
else:
77+
start_event = torch.cuda.Event(enable_timing=True)
78+
end_event = torch.cuda.Event(enable_timing=True)
6979

7080
if args.scan:
7181
# Create list of message sizes

benchmarks/communication/all_reduce.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616

1717
def timed_all_reduce(input, start_event, end_event, args):
18+
if args.device == "cpu":
19+
print_rank_0(f"No Event support on CPU to measure time for now")
20+
return
1821
if args.dist == 'torch':
1922
import torch.distributed as dist
2023
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args):
6063
world_size = dist.get_world_size()
6164
global_rank = dist.get_rank()
6265

63-
start_event = torch.cuda.Event(enable_timing=True)
64-
end_event = torch.cuda.Event(enable_timing=True)
66+
if args.device == "xpu":
67+
start_event = torch.xpu.Event(enable_timing=True)
68+
end_event = torch.xpu.Event(enable_timing=True)
69+
elif args.device == "cpu":
70+
start_event = torch.cpu.Event()
71+
end_event = torch.cpu.Event()
72+
else:
73+
start_event = torch.cuda.Event(enable_timing=True)
74+
end_event = torch.cuda.Event(enable_timing=True)
6575

6676
if args.scan:
6777
M_LIST = []

benchmarks/communication/all_to_all.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616

1717
def timed_all_to_all(input, output, start_event, end_event, args):
18+
if args.device == "cpu":
19+
print_rank_0(f"No Event support on CPU to measure time for now")
20+
return
1821
if args.dist == 'torch':
1922
import torch.distributed as dist
2023
elif args.dist == 'deepspeed':
@@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args):
5962
# Prepare benchmark header
6063
print_header(args, 'all_to_all')
6164

62-
start_event = torch.cuda.Event(enable_timing=True)
63-
end_event = torch.cuda.Event(enable_timing=True)
65+
if args.device == "xpu":
66+
start_event = torch.xpu.Event(enable_timing=True)
67+
end_event = torch.xpu.Event(enable_timing=True)
68+
elif args.device == "cpu":
69+
start_event = torch.cpu.Event()
70+
end_event = torch.cpu.Event()
71+
else:
72+
start_event = torch.cuda.Event(enable_timing=True)
73+
end_event = torch.cuda.Event(enable_timing=True)
6474

6575
if args.scan:
6676
M_LIST = []

benchmarks/communication/broadcast.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616

1717
def timed_broadcast(input, start_event, end_event, args):
18+
if args.device == "cpu":
19+
print_rank_0(f"No Event support on CPU to measure time for now")
20+
return
1821
if args.dist == 'torch':
1922
import torch.distributed as dist
2023
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_broadcast(local_rank, args):
6063
world_size = dist.get_world_size()
6164
global_rank = dist.get_rank()
6265

63-
start_event = torch.cuda.Event(enable_timing=True)
64-
end_event = torch.cuda.Event(enable_timing=True)
66+
if args.device == "xpu":
67+
start_event = torch.xpu.Event(enable_timing=True)
68+
end_event = torch.xpu.Event(enable_timing=True)
69+
elif args.device == "cpu":
70+
start_event = torch.cpu.Event()
71+
end_event = torch.cpu.Event()
72+
else:
73+
start_event = torch.cuda.Event(enable_timing=True)
74+
end_event = torch.cuda.Event(enable_timing=True)
6575

6676
if args.scan:
6777
M_LIST = []

benchmarks/communication/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
DEFAULT_UNIT = 'Gbps'
1313
DEFAULT_DIST = 'deepspeed'
1414
DEFAULT_MAXSIZE = 24
15+
DEFAULT_DEVICE = 'cuda'
1516
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500

benchmarks/communication/pt2pt.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616

1717
def timed_pt2pt(input, start_event, end_event, args):
18+
if args.device == "cpu":
19+
print_rank_0(f"No Event support on CPU to measure time for now")
20+
return
1821
if args.dist == 'torch':
1922
import torch.distributed as dist
2023
elif args.dist == 'deepspeed':
@@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args):
7881
global_rank = dist.get_rank()
7982
world_size = dist.get_world_size()
8083

81-
start_event = torch.cuda.Event(enable_timing=True)
82-
end_event = torch.cuda.Event(enable_timing=True)
84+
if args.device == "xpu":
85+
start_event = torch.xpu.Event(enable_timing=True)
86+
end_event = torch.xpu.Event(enable_timing=True)
87+
elif args.device == "cpu":
88+
start_event = torch.cpu.Event()
89+
end_event = torch.cpu.Event()
90+
else:
91+
start_event = torch.cuda.Event(enable_timing=True)
92+
end_event = torch.cuda.Event(enable_timing=True)
8393

8494
if args.scan:
8595
# Create list of message sizes

benchmarks/communication/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def get_bw(comm_op, size, duration, args):
108108
n = dist.get_world_size()
109109
tput = 0
110110
busbw = 0
111+
112+
if duration == 0:
113+
print_rank_0("Error. Duration is 0.")
114+
return tput, busbw
115+
111116
if comm_op == "all_to_all":
112117
tput = (size / duration)
113118
busbw = (size / duration) * ((n - 1) / n)
@@ -235,4 +240,5 @@ def benchmark_parser():
235240
default=.3,
236241
help='Proportion of max available GPU memory to use for single-size evals')
237242
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
243+
parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='target device')
238244
return parser

0 commit comments

Comments
 (0)