Skip to content

Commit b78f6f8

Browse files
committed
Add revision to hub
1 parent d0f4352 commit b78f6f8

2 files changed

Lines changed: 16 additions & 3 deletions

File tree

speechbrain/pretrained/fetching.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def fetch(
3030
overwrite=False,
3131
save_filename=None,
3232
use_auth_token=False,
33+
revision=None,
3334
):
3435
"""Ensures you have a local copy of the file, returns its path
3536
@@ -65,6 +66,10 @@ def fetch(
6566
use_auth_token : bool (default: False)
6667
If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
6768
default is False because majority of models are public.
69+
revision : str
70+
The model revision corresponding to the HuggingFace Hub model revision.
71+
This is particularly useful if you wish to pin your code to a particular
72+
version of a model hosted at HuggingFace.
6873
Returns
6974
-------
7075
pathlib.Path
@@ -113,7 +118,10 @@ def fetch(
113118
logger.info(MSG)
114119
try:
115120
fetched_file = huggingface_hub.hf_hub_download(
116-
repo_id=source, filename=filename, use_auth_token=use_auth_token
121+
repo_id=source,
122+
filename=filename,
123+
use_auth_token=use_auth_token,
124+
revision=revision,
117125
)
118126
except HTTPError as e:
119127
if "404 Client Error" in str(e):

speechbrain/pretrained/interfaces.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def from_hparams(
289289
overrides={},
290290
savedir=None,
291291
use_auth_token=False,
292+
revision=None,
292293
**kwargs,
293294
):
294295
"""Fetch and load based from outside source based on HyperPyYAML file
@@ -332,16 +333,20 @@ def from_hparams(
332333
use_auth_token : bool (default: False)
333334
If true Hugginface's auth_token will be used to load private models from the HuggingFace Hub,
334335
default is False because majority of models are public.
336+
revision : str
337+
The model revision corresponding to the HuggingFace Hub model revision.
338+
This is particularly useful if you wish to pin your code to a particular
339+
version of a model hosted at HuggingFace.
335340
"""
336341
if savedir is None:
337342
clsname = cls.__name__
338343
savedir = f"./pretrained_models/{clsname}-{hashlib.md5(source.encode('UTF-8', errors='replace')).hexdigest()}"
339344
hparams_local_path = fetch(
340-
hparams_file, source, savedir, use_auth_token
345+
hparams_file, source, savedir, use_auth_token, revision
341346
)
342347
try:
343348
pymodule_local_path = fetch(
344-
pymodule_file, source, savedir, use_auth_token
349+
pymodule_file, source, savedir, use_auth_token, revision
345350
)
346351
sys.path.append(str(pymodule_local_path.parent))
347352
except ValueError:

0 commit comments

Comments
 (0)