11import urllib .request
22import torch
3+ import os
4+ import sys
35
46# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
57def checkpoint_from_distributed (state_dict ):
@@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
5456 from PyTorch .Recommendation .NCF import neumf as ncf
5557
5658 fp16 = "model_math" in kwargs and kwargs ["model_math" ] == "fp16"
59+ force_reload = "force_reload" in kwargs and kwargs ["force_reload" ]
5760
5861 config = {'nb_users' : None , 'nb_items' : None , 'mf_dim' : 64 , 'mf_reg' : 0. ,
5962 'mlp_layer_sizes' : [256 , 256 , 128 , 64 ], 'mlp_layer_regs' :[0 , 0 , 0 , 0 ], 'dropout' : 0.5 }
@@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
6366 checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
6467 else :
6568 checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
66- ckpt_file = "ncf_ckpt.pt"
67- urllib .request .urlretrieve (checkpoint , ckpt_file )
69+ ckpt_file = os .path .basename (checkpoint )
70+ if not os .path .exists (ckpt_file ) or force_reload :
71+ sys .stderr .write ('Downloading checkpoint from {}\n ' .format (checkpoint ))
72+ urllib .request .urlretrieve (checkpoint , ckpt_file )
6873 ckpt = torch .load (ckpt_file )
6974
7075 if checkpoint_from_distributed (ckpt ):
@@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
117122 from PyTorch .SpeechSynthesis .Tacotron2 .models import lstmcell_to_float , batchnorm_to_float
118123
119124 fp16 = "model_math" in kwargs and kwargs ["model_math" ] == "fp16"
125+ force_reload = "force_reload" in kwargs and kwargs ["force_reload" ]
120126
121127 if pretrained :
122128 if fp16 :
123129 checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
124130 else :
125131 checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
126- ckpt_file = "tacotron2_ckpt.pt"
127- urllib .request .urlretrieve (checkpoint , ckpt_file )
132+ ckpt_file = os .path .basename (checkpoint )
133+ if not os .path .exists (ckpt_file ) or force_reload :
134+ sys .stderr .write ('Downloading checkpoint from {}\n ' .format (checkpoint ))
135+ urllib .request .urlretrieve (checkpoint , ckpt_file )
128136 ckpt = torch .load (ckpt_file )
129137 state_dict = ckpt ['state_dict' ]
130138 if checkpoint_from_distributed (state_dict ):
@@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
172180 from PyTorch .SpeechSynthesis .Tacotron2 .models import batchnorm_to_float
173181
174182 fp16 = "model_math" in kwargs and kwargs ["model_math" ] == "fp16"
183+ force_reload = "force_reload" in kwargs and kwargs ["force_reload" ]
175184
176185 if pretrained :
177186 if fp16 :
178187 checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
179188 else :
180189 checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
181- ckpt_file = "waveglow_ckpt.pt"
182- urllib .request .urlretrieve (checkpoint , ckpt_file )
190+ ckpt_file = os .path .basename (checkpoint )
191+ if not os .path .exists (ckpt_file ) or force_reload :
192+ sys .stderr .write ('Downloading checkpoint from {}\n ' .format (checkpoint ))
193+ urllib .request .urlretrieve (checkpoint , ckpt_file )
183194 ckpt = torch .load (ckpt_file )
184195 state_dict = ckpt ['state_dict' ]
185196 if checkpoint_from_distributed (state_dict ):
0 commit comments