@@ -49,72 +49,6 @@ def parse_args(parser):
4949 return parser
5050
5151
52- def convert_convinv_1d_to_2d (convinv ):
53- """
54- Takes an invertible 1x1 1-d convolution and returns a 2-d convolution that does
55- the inverse
56- """
57- conv2d = torch .nn .Conv2d (convinv .W_inverse .size (1 ),
58- convinv .W_inverse .size (0 ),
59- 1 , bias = False )
60- conv2d .weight .data [:,:,:,0 ] = convinv .W_inverse .data
61- return conv2d
62-
63-
64- def convert_conv_1d_to_2d (conv1d ):
65- conv2d = torch .nn .Conv2d (conv1d .weight .size (1 ),
66- conv1d .weight .size (0 ),
67- (conv1d .weight .size (2 ), 1 ),
68- stride = (conv1d .stride [0 ], 1 ),
69- dilation = (conv1d .dilation [0 ], 1 ),
70- padding = (conv1d .padding [0 ], 0 ))
71- conv2d .weight .data [:,:,:,0 ] = conv1d .weight .data
72- conv2d .bias .data = conv1d .bias .data
73- return conv2d
74-
75-
76- def convert_WN_1d_to_2d_ (WN ):
77- """
78- Modifies the WaveNet like affine coupling layer in-place to use 2-d convolutions
79- """
80- WN .start = convert_conv_1d_to_2d (WN .start )
81- WN .end = convert_conv_1d_to_2d (WN .end )
82-
83- for i in range (len (WN .in_layers )):
84- WN .in_layers [i ] = convert_conv_1d_to_2d (WN .in_layers [i ])
85-
86- for i in range (len (WN .res_skip_layers )):
87- WN .res_skip_layers [i ] = convert_conv_1d_to_2d (WN .res_skip_layers [i ])
88-
89- for i in range (len (WN .res_skip_layers )):
90- WN .cond_layers [i ] = convert_conv_1d_to_2d (WN .cond_layers [i ])
91-
92- def convert_1d_to_2d_ (glow ):
93- """
94- Caffe2 and TensorRT don't seem to support 1-d convolutions or properly
95- convert ONNX exports with 1d convolutions to 2d convolutions yet, so we
96- do the conversion to 2-d convolutions before ONNX export
97- """
98- # Convert upsample to 2d
99- upsample = torch .nn .ConvTranspose2d (glow .upsample .weight .size (0 ),
100- glow .upsample .weight .size (1 ),
101- (glow .upsample .weight .size (2 ), 1 ),
102- stride = (glow .upsample .stride [0 ], 1 ))
103- upsample .weight .data [:,:,:,0 ] = glow .upsample .weight .data
104- upsample .bias .data = glow .upsample .bias .data
105- glow .upsample = upsample .cuda ()
106-
107- # Convert WN to 2d
108- for WN in glow .WN :
109- convert_WN_1d_to_2d_ (WN )
110-
111- # Convert invertible conv to 2d
112- for i in range (len (glow .convinv )):
113- glow .convinv [i ] = convert_convinv_1d_to_2d (glow .convinv [i ])
114-
115- glow .cuda ()
116-
117-
11852def infer_onnx (self , spect , z , sigma = 0.9 ):
11953
12054 spect = self .upsample (spect )
@@ -126,37 +60,33 @@ def infer_onnx(self, spect, z, sigma=0.9):
12660 mel_dim = 80
12761 batch_size = spect .size (0 )
12862
129- spect = torch .squeeze (spect , 3 )
13063 spect = spect .view ((batch_size , mel_dim , length_spect_group , self .n_group ))
13164 spect = spect .permute (0 , 2 , 1 , 3 )
13265 spect = spect .contiguous ()
13366 spect = spect .view ((batch_size , length_spect_group , self .n_group * mel_dim ))
13467 spect = spect .permute (0 , 2 , 1 )
135- spect = torch .unsqueeze (spect , 3 )
13668 spect = spect .contiguous ()
13769
138- audio = z [:, :self .n_remaining_channels , :, : ]
139- z = z [:, self .n_remaining_channels :self .n_group , :, : ]
70+ audio = z [:, :self .n_remaining_channels , :]
71+ z = z [:, self .n_remaining_channels :self .n_group , :]
14072 audio = sigma * audio
14173
14274 for k in reversed (range (self .n_flows )):
14375 n_half = int (audio .size (1 ) / 2 )
144- audio_0 = audio [:, :n_half , :, : ]
145- audio_1 = audio [:, n_half :(n_half + n_half ), :, : ]
76+ audio_0 = audio [:, :n_half , :]
77+ audio_1 = audio [:, n_half :(n_half + n_half ), :]
14678
14779 output = self .WN [k ]((audio_0 , spect ))
148- s = output [:, n_half :(n_half + n_half ), :, : ]
149- b = output [:, :n_half , :, : ]
80+ s = output [:, n_half :(n_half + n_half ), :]
81+ b = output [:, :n_half , :]
15082 audio_1 = (audio_1 - b ) / torch .exp (s )
15183 audio = torch .cat ([audio_0 , audio_1 ], 1 )
152-
153- audio = self .convinv [k ](audio )
84+ audio = self .convinv [k ].infer (audio )
15485
15586 if k % self .n_early_every == 0 and k > 0 :
156- audio = torch .cat ((z [:, :self .n_early_size , :, : ], audio ), 1 )
157- z = z [:, self .n_early_size :self .n_group , :, : ]
87+ audio = torch .cat ((z [:, :self .n_early_size , :], audio ), 1 )
88+ z = z [:, self .n_early_size :self .n_group , :]
15889
159- audio = torch .squeeze (audio , 3 )
16090 audio = audio .permute (0 ,2 ,1 ).contiguous ().view (batch_size , (length_spect_group * self .n_group ))
16191
16292 return audio
@@ -165,15 +95,15 @@ def infer_onnx(self, spect, z, sigma=0.9):
16595def export_onnx (parser , args ):
16696
16797 waveglow = load_and_setup_model ('WaveGlow' , parser , args .waveglow ,
168- amp_run = args .fp16 , cpu_run = False ,
98+ fp16_run = args .fp16 , cpu_run = False ,
16999 forward_is_infer = False )
170100
171101 # 80 mel channels, 620 mel spectrograms ~ 7 seconds of speech
172102 mel = torch .randn (1 , 80 , 620 ).cuda ()
173103 stride = 256 # value from waveglow upsample
174104 n_group = 8
175105 z_size2 = (mel .size (2 )* stride )// n_group
176- z = torch .randn (1 , n_group , z_size2 , 1 ).cuda ()
106+ z = torch .randn (1 , n_group , z_size2 ).cuda ()
177107
178108 if args .fp16 :
179109 mel = mel .half ()
@@ -183,16 +113,13 @@ def export_onnx(parser, args):
183113 waveglow .infer (mel , sigma = args .sigma_infer )
184114
185115 # export to ONNX
186- convert_1d_to_2d_ (waveglow )
187116 if args .fp16 :
188117 waveglow = waveglow .half ()
189118
190119 fType = types .MethodType
191120 waveglow .forward = fType (infer_onnx , waveglow )
192121
193- mel = mel .unsqueeze (3 )
194-
195- opset_version = 10
122+ opset_version = 12
196123
197124 torch .onnx .export (waveglow , (mel , z ), args .output + "/" + "waveglow.onnx" ,
198125 opset_version = opset_version ,
0 commit comments