Skip to content

Commit 6719b46

Browse files
authored
fix typo when getting kernel dim in conv calculation (deepspeedai#1989)
1 parent b6f2a56 commit 6719b46

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

deepspeed/profiling/flops_profiler/profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def _conv_flops_compute(input,
534534
batch_size = input.shape[0]
535535
in_channels = input.shape[1]
536536
out_channels = weight.shape[0]
537-
kernel_dims = list(weight.shape[-2:])
537+
kernel_dims = list(weight.shape[2:])
538538
input_dims = list(input.shape[2:])
539539

540540
length = len(input_dims)
@@ -575,7 +575,7 @@ def _conv_trans_flops_compute(
575575
batch_size = input.shape[0]
576576
in_channels = input.shape[1]
577577
out_channels = weight.shape[0]
578-
kernel_dims = list(weight.shape[-2:])
578+
kernel_dims = list(weight.shape[2:])
579579
input_dims = list(input.shape[2:])
580580

581581
length = len(input_dims)

0 commit comments

Comments
 (0)