Skip to content

Commit 8b0966e

Browse files
committed
HTTPStorage class; download-checkpoint now uses generic Storage
1 parent 8a136fc commit 8b0966e

File tree

3 files changed

+68
-21
lines changed

3 files changed

+68
-21
lines changed

download-checkpoint.py

Lines changed: 3 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,14 @@
11
import os
2-
import requests
3-
from tqdm import tqdm
2+
from utils import Storage
43

54
# URL of the checkpoint to fetch; None when the environment variable is unset.
CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL")
65

7-
8-
def download(url: str, fname: str):
    """Stream the resource at *url* into the local file *fname*.

    A tqdm progress bar is shown while downloading; its total comes from
    the ``Content-Length`` response header and is 0 when the header is
    missing.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    # Context-managing the response releases the connection even on failure.
    with requests.get(url, stream=True) as resp:
        # Fail before opening the file so that an error page is never
        # written to disk in place of the requested download.
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        # Can also replace 'file' with a io.BytesIO object
        with open(fname, "wb") as file, tqdm(
            desc="Downloading",
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
22-
23-
246
if __name__ == "__main__":
    # Nothing to do unless a checkpoint URL was supplied via the environment.
    if CHECKPOINT_URL:
        CHECKPOINT_DIR = "/root/.cache/checkpoints"
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        # The local file is named after the last "/"-separated URL segment.
        fname = CHECKPOINT_DIR + "/" + CHECKPOINT_URL.split("/").pop()
        # Skip the download when the file is already present locally.
        if not os.path.isfile(fname):
            storage = Storage(CHECKPOINT_URL)
            storage.download_file(fname)

utils/storage/HTTPStorage.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
import re
2+
import os
3+
import time
4+
import requests
5+
from tqdm import tqdm
6+
7+
8+
def get_now():
    """Return the current Unix time in milliseconds, rounded to an integer."""
    millis = time.time() * 1000
    return round(millis)
10+
11+
12+
class HTTPStorage:
    """Download-only storage backend that fetches one file over HTTP(S).

    Attributes set by the constructor:
        url: the URL exactly as passed in; download_file() requests it.
        endpoint_url: scheme and host of the URL, e.g. "https://example.com".
        bucket_name: first path segment of the URL, or None if there is none.
        path: everything after the first path segment, or the ``path``
            constructor argument when that part is missing or empty.
    """

    # Splits a URL into endpoint (scheme + host), first path segment, and rest.
    _URL_PATTERN = re.compile(
        "^(?P<endpoint>https?://[^/]*)(/(?P<bucket>[^/]+))?(/(?P<path>.*))?$"
    )

    # Number of bytes requested per iter_content() step while streaming.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, url, path=""):
        """Parse *url* into endpoint, bucket and path components.

        Args:
            url: an ``http://`` or ``https://`` URL. The ``s3://``,
                ``http+s3://`` and ``https+s3://`` prefixes are rewritten to
                plain HTTP(S) for parsing only; ``self.url`` keeps the
                original string.
            path: fallback used for ``self.path`` when the URL carries no
                path after its first segment.

        Raises:
            ValueError: if the URL is not HTTP(S) or has no host.
        """
        self.url = url

        if url.startswith("s3://"):
            url = "https://" + url[5:]
        elif url.startswith("http+s3://"):
            url = "http" + url[7:]
        elif url.startswith("https+s3://"):
            url = "https" + url[8:]

        match = self._URL_PATTERN.match(url)
        if match is None:
            raise ValueError("Not an HTTP(S) URL: " + self.url)
        parts = match.groupdict()

        # An endpoint ending in "//" means the scheme was not followed by a
        # host. No default endpoint or bucket is defined in this module, so
        # a missing host is an error and a missing bucket is left as None.
        if parts["endpoint"].endswith("//"):
            raise ValueError("URL has no host: " + self.url)

        self.endpoint_url = parts["endpoint"]
        self.bucket_name = parts["bucket"]
        self.path = parts["path"] or path

        self._s3 = None
        self._bucket = None
        print("self.endpoint_url", self.endpoint_url)

    def upload_file(self, source, dest):
        """Uploading is not supported by this backend.

        Raises:
            NotImplementedError: always. It is a subclass of RuntimeError,
                so callers catching RuntimeError keep working.
        """
        raise NotImplementedError("HTTP PUT not implemented yet")

    def download_file(self, fname):
        """Stream ``self.url`` into the local file *fname* with a progress bar.

        Data is first written to ``fname + ".part"`` and renamed onto
        *fname* only once the transfer has finished, so a failed or
        interrupted download never leaves a truncated file at *fname*.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        print(f"Downloading {self.url} to {fname}...")
        partial = f"{fname}.part"
        try:
            # Context-managing the response releases the connection on exit.
            with requests.get(self.url, stream=True) as resp:
                # Fail early so an error page is never saved as the file.
                resp.raise_for_status()
                total = int(resp.headers.get("content-length", 0))
                # Can also replace 'file' with a io.BytesIO object
                with open(partial, "wb") as file, tqdm(
                    desc="Downloading",
                    total=total,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in resp.iter_content(chunk_size=self.CHUNK_SIZE):
                        bar.update(file.write(data))
            os.replace(partial, fname)
        finally:
            # After a successful replace the partial file no longer exists;
            # on any failure this removes the incomplete download.
            if os.path.exists(partial):
                os.remove(partial)

utils/storage/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,14 @@
11
import os
22
import re
33
from .S3Storage import S3Storage
4+
from .HTTPStorage import HTTPStorage
45

56

67
def Storage(url):
    """Return the storage backend that handles *url*.

    ``s3://``, ``http+s3://`` and ``https+s3://`` URLs get an S3Storage;
    plain ``http://`` and ``https://`` URLs get an HTTPStorage.

    Raises:
        RuntimeError: if no backend handles the URL's scheme.
    """
    # Raw strings: "\+" inside a plain string literal is an invalid escape
    # sequence and triggers a warning on current Python versions.
    if re.match(r"(https?\+)?s3://", url):
        return S3Storage(url)

    if re.match(r"https?://", url):
        return HTTPStorage(url)

    raise RuntimeError("No storage handler for: " + url)

0 commit comments

Comments
 (0)