3131import requests
3232import tqdm
3333
34- H5_FILENAME = 'cmu_2020_dfe3e9e0.h5'
35- H5_PATHS = (
36- os .path .join (os .path .dirname (__file__ ), H5_FILENAME ),
37- os .path .join ('~/.dm_control' , H5_FILENAME ),
38- )
39- H5_URL = 'https://storage.googleapis.com/dm_control/cmu_2020_dfe3e9e0.h5'
34+ H5_FILENAME = {'2019' : 'cmu_2019_08756c01.h5' ,
35+ '2020' : 'cmu_2020_dfe3e9e0.h5' }
4036
41- H5_BYTES = 476559420
42- H5_SHA256 = 'dfe3e9e0b08d32960bdafbf89e541339ca8908a9a5e7f4a2c986362890d72863'
37+ H5_PATHS = {k : (os .path .join (os .path .dirname (__file__ ), v ),
38+ os .path .join ('~/.dm_control' , v ))
39+ for k , v in H5_FILENAME .items ()}
40+ H5_URL_BASE = 'https://storage.googleapis.com/dm_control/'
41+ H5_URL = {'2019' : H5_URL_BASE + 'cmu_2019_08756c01.h5' ,
42+ '2020' : H5_URL_BASE + 'cmu_2020_dfe3e9e0.h5' }
4343
44+ H5_BYTES = {'2019' : 488143314 ,
45+ '2020' : 476559420 }
46+ H5_SHA256 = {
47+ '2019' : '08756c01cb4ac20da9918e70e85c32d4880c6c8c16189b02a18b79a5e79afa2b' ,
48+ '2020' : 'dfe3e9e0b08d32960bdafbf89e541339ca8908a9a5e7f4a2c986362890d72863' }
4449
45- def _get_cached_file_path ():
50+
51+ def _get_cached_file_path (version ):
4652 """Returns the path to the cached data file if one exists."""
47- for path in H5_PATHS :
53+ for path in H5_PATHS [ version ] :
4854 expanded_path = os .path .expanduser (path )
4955 try :
50- if os .path .getsize (expanded_path ) != H5_BYTES :
56+ if os .path .getsize (expanded_path ) != H5_BYTES [ version ] :
5157 continue
5258 with open (expanded_path , 'rb' ):
5359 return expanded_path
@@ -56,9 +62,9 @@ def _get_cached_file_path():
5662 return None
5763
5864
59- def _download_and_cache ():
65+ def _download_and_cache (version ):
6066 """Downloads CMU data into one of the candidate paths in H5_PATHS."""
61- for path in H5_PATHS :
67+ for path in H5_PATHS [ version ] :
6268 expanded_path = os .path .expanduser (path )
6369 try :
6470 os .makedirs (os .path .dirname (expanded_path ), exist_ok = True )
@@ -67,18 +73,18 @@ def _download_and_cache():
6773 continue
6874 with f :
6975 try :
70- _download_into_file (f )
76+ _download_into_file (f , version )
7177 except :
7278 os .unlink (expanded_path )
7379 raise
7480 return expanded_path
7581 raise IOError ('cannot open file to write download data into, '
76- f'paths attempted: { H5_PATHS } ' )
82+ f'paths attempted: { H5_PATHS [ version ] } ' )
7783
7884
79- def _download_into_file (f , validate_hash = True ):
85+ def _download_into_file (f , version , validate_hash = True ):
8086 """Download the CMU data into a file object that has been opened for write."""
81- with requests .get (H5_URL , stream = True ) as req :
87+ with requests .get (H5_URL [ version ] , stream = True ) as req :
8288 req .raise_for_status ()
8389 total_bytes = int (req .headers ['Content-Length' ])
8490 progress_bar = tqdm .tqdm (
@@ -94,13 +100,15 @@ def _download_into_file(f, validate_hash=True):
94100
95101 if validate_hash :
96102 f .seek (0 )
97- if hashlib .sha256 (f .read ()).hexdigest () != H5_SHA256 :
103+ if hashlib .sha256 (f .read ()).hexdigest () != H5_SHA256 [ version ] :
98104 raise RuntimeError ('downloaded file is corrupted' )
99105
100106
101- def get_path_for_cmu_2020 ():
102- """Path to mocap data fitted to the 2020 version of the CMU Humanoid model."""
103- path = _get_cached_file_path ()
107+ def get_path_for_cmu (version = '2020' ):
108+ """Path to mocap data fitted to a version of the CMU Humanoid model."""
109+ assert version in H5_FILENAME .keys ()
110+ path = _get_cached_file_path (version )
104111 if path is None :
105- path = _download_and_cache ()
112+ path = _download_and_cache (version )
106113 return path
114+
0 commit comments