Skip to content

Commit 161a4ea

Browse files
committed
caching of pre-trained weights added to entrypoints
1 parent 14fe91f commit 161a4ea

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

hubconf.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import urllib.request
22
import torch
3+
import os
4+
import sys
35

46
# from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
57
def checkpoint_from_distributed(state_dict):
@@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
5456
from PyTorch.Recommendation.NCF import neumf as ncf
5557

5658
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
59+
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
5760

5861
config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
5962
'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs':[0, 0, 0, 0], 'dropout': 0.5}
@@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
6366
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
6467
else:
6568
checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
66-
ckpt_file = "ncf_ckpt.pt"
67-
urllib.request.urlretrieve(checkpoint, ckpt_file)
69+
ckpt_file = os.path.basename(checkpoint)
70+
if not os.path.exists(ckpt_file) or force_reload:
71+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
72+
urllib.request.urlretrieve(checkpoint, ckpt_file)
6873
ckpt = torch.load(ckpt_file)
6974

7075
if checkpoint_from_distributed(ckpt):
@@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
117122
from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float
118123

119124
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
125+
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
120126

121127
if pretrained:
122128
if fp16:
123129
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
124130
else:
125131
checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
126-
ckpt_file = "tacotron2_ckpt.pt"
127-
urllib.request.urlretrieve(checkpoint, ckpt_file)
132+
ckpt_file = os.path.basename(checkpoint)
133+
if not os.path.exists(ckpt_file) or force_reload:
134+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
135+
urllib.request.urlretrieve(checkpoint, ckpt_file)
128136
ckpt = torch.load(ckpt_file)
129137
state_dict = ckpt['state_dict']
130138
if checkpoint_from_distributed(state_dict):
@@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
172180
from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float
173181

174182
fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
183+
force_reload = "force_reload" in kwargs and kwargs["force_reload"]
175184

176185
if pretrained:
177186
if fp16:
178187
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
179188
else:
180189
checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
181-
ckpt_file = "waveglow_ckpt.pt"
182-
urllib.request.urlretrieve(checkpoint, ckpt_file)
190+
ckpt_file = os.path.basename(checkpoint)
191+
if not os.path.exists(ckpt_file) or force_reload:
192+
sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
193+
urllib.request.urlretrieve(checkpoint, ckpt_file)
183194
ckpt = torch.load(ckpt_file)
184195
state_dict = ckpt['state_dict']
185196
if checkpoint_from_distributed(state_dict):

0 commit comments

Comments
 (0)