Skip to content

Commit 4a8e6fb

Browse files
committed
Cleanup
1 parent 6b2d6ff commit 4a8e6fb

12 files changed

Lines changed: 205 additions & 533 deletions

File tree

pyrogram/client/client.py

Lines changed: 120 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
18-
import io
18+
1919
import logging
2020
import math
2121
import os
@@ -1231,9 +1231,9 @@ def download_worker(self):
12311231

12321232
temp_file_path = ""
12331233
final_file_path = ""
1234-
path = [None]
1234+
12351235
try:
1236-
data, done, progress, progress_args, out, path, to_file = packet
1236+
data, directory, file_name, done, progress, progress_args, path = packet
12371237

12381238
temp_file_path = self.get_file(
12391239
media_type=data.media_type,
@@ -1250,15 +1250,13 @@ def download_worker(self):
12501250
file_size=data.file_size,
12511251
is_big=data.is_big,
12521252
progress=progress,
1253-
progress_args=progress_args,
1254-
out=out
1253+
progress_args=progress_args
12551254
)
1256-
if to_file:
1257-
final_file_path = out.name
1258-
else:
1259-
final_file_path = ''
1260-
if to_file:
1261-
out.close()
1255+
1256+
if temp_file_path:
1257+
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
1258+
os.makedirs(directory, exist_ok=True)
1259+
shutil.move(temp_file_path, final_file_path)
12621260
except Exception as e:
12631261
log.error(e, exc_info=True)
12641262

@@ -1715,7 +1713,7 @@ def resolve_peer(self, peer_id: Union[int, str]):
17151713

17161714
def save_file(
17171715
self,
1718-
path: Union[str, io.IOBase],
1716+
path: str,
17191717
file_id: int = None,
17201718
file_part: int = 0,
17211719
progress: callable = None,
@@ -1767,20 +1765,9 @@ def save_file(
17671765
17681766
Raises:
17691767
RPCError: In case of a Telegram RPC error.
1770-
ValueError: if path is not str or file-like readable object
17711768
"""
17721769
part_size = 512 * 1024
1773-
if isinstance(path, str):
1774-
fp = open(path, 'rb')
1775-
filename = os.path.basename(path)
1776-
elif hasattr(path, 'write'):
1777-
fp = path
1778-
filename = fp.name
1779-
else:
1780-
raise ValueError("Invalid path passed! Pass file pointer or path to file")
1781-
fp.seek(0, os.SEEK_END)
1782-
file_size = fp.tell()
1783-
fp.seek(0)
1770+
file_size = os.path.getsize(path)
17841771

17851772
if file_size == 0:
17861773
raise ValueError("File size equals to 0 B")
@@ -1798,74 +1785,67 @@ def save_file(
17981785
session.start()
17991786

18001787
try:
1801-
fp.seek(part_size * file_part)
1788+
with open(path, "rb") as f:
1789+
f.seek(part_size * file_part)
18021790

1803-
while True:
1804-
chunk = fp.read(part_size)
1805-
1806-
if not chunk:
1807-
if not is_big:
1808-
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
1809-
break
1810-
1811-
for _ in range(3):
1812-
if is_big:
1813-
rpc = functions.upload.SaveBigFilePart(
1814-
file_id=file_id,
1815-
file_part=file_part,
1816-
file_total_parts=file_total_parts,
1817-
bytes=chunk
1818-
)
1819-
else:
1820-
rpc = functions.upload.SaveFilePart(
1821-
file_id=file_id,
1822-
file_part=file_part,
1823-
bytes=chunk
1824-
)
1791+
while True:
1792+
chunk = f.read(part_size)
18251793

1826-
if session.send(rpc):
1794+
if not chunk:
1795+
if not is_big:
1796+
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
18271797
break
1828-
else:
1829-
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
18301798

1831-
if is_missing_part:
1832-
return
1799+
for _ in range(3):
1800+
if is_big:
1801+
rpc = functions.upload.SaveBigFilePart(
1802+
file_id=file_id,
1803+
file_part=file_part,
1804+
file_total_parts=file_total_parts,
1805+
bytes=chunk
1806+
)
1807+
else:
1808+
rpc = functions.upload.SaveFilePart(
1809+
file_id=file_id,
1810+
file_part=file_part,
1811+
bytes=chunk
1812+
)
1813+
1814+
if session.send(rpc):
1815+
break
1816+
else:
1817+
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
1818+
1819+
if is_missing_part:
1820+
return
18331821

1834-
if not is_big:
1835-
md5_sum.update(chunk)
1822+
if not is_big:
1823+
md5_sum.update(chunk)
18361824

1837-
file_part += 1
1825+
file_part += 1
18381826

1839-
if progress:
1840-
progress(min(file_part * part_size, file_size), file_size, *progress_args)
1827+
if progress:
1828+
progress(min(file_part * part_size, file_size), file_size, *progress_args)
18411829
except Client.StopTransmission:
1842-
if isinstance(path, str):
1843-
fp.close()
18441830
raise
18451831
except Exception as e:
1846-
if isinstance(path, str):
1847-
fp.close()
18481832
log.error(e, exc_info=True)
18491833
else:
1850-
if isinstance(path, str):
1851-
fp.close()
18521834
if is_big:
18531835
return types.InputFileBig(
18541836
id=file_id,
18551837
parts=file_total_parts,
1856-
name=filename,
1838+
name=os.path.basename(path),
18571839

18581840
)
18591841
else:
18601842
return types.InputFile(
18611843
id=file_id,
18621844
parts=file_total_parts,
1863-
name=filename,
1845+
name=os.path.basename(path),
18641846
md5_checksum=md5_sum
18651847
)
18661848
finally:
1867-
if isinstance(path, str):
1868-
fp.close()
18691849
session.stop()
18701850

18711851
def get_file(
@@ -1884,8 +1864,7 @@ def get_file(
18841864
file_size: int,
18851865
is_big: bool,
18861866
progress: callable,
1887-
progress_args: tuple = (),
1888-
out: io.IOBase = None
1867+
progress_args: tuple = ()
18891868
) -> str:
18901869
with self.media_sessions_lock:
18911870
session = self.media_sessions.get(dc_id, None)
@@ -1971,10 +1950,7 @@ def get_file(
19711950
limit = 1024 * 1024
19721951
offset = 0
19731952
file_name = ""
1974-
if not out:
1975-
f = tempfile.NamedTemporaryFile("wb", delete=False)
1976-
else:
1977-
f = out
1953+
19781954
try:
19791955
r = session.send(
19801956
functions.upload.GetFile(
@@ -1985,36 +1961,35 @@ def get_file(
19851961
)
19861962

19871963
if isinstance(r, types.upload.File):
1988-
if hasattr(f, "name"):
1964+
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
19891965
file_name = f.name
19901966

1991-
while True:
1992-
chunk = r.bytes
1993-
1994-
if not chunk:
1995-
break
1967+
while True:
1968+
chunk = r.bytes
19961969

1997-
f.write(chunk)
1970+
if not chunk:
1971+
break
19981972

1999-
offset += limit
1973+
f.write(chunk)
20001974

2001-
if progress:
2002-
progress(
1975+
offset += limit
20031976

2004-
min(offset, file_size)
2005-
if file_size != 0
2006-
else offset,
2007-
file_size,
2008-
*progress_args
2009-
)
1977+
if progress:
1978+
progress(
1979+
min(offset, file_size)
1980+
if file_size != 0
1981+
else offset,
1982+
file_size,
1983+
*progress_args
1984+
)
20101985

2011-
r = session.send(
2012-
functions.upload.GetFile(
2013-
location=location,
2014-
offset=offset,
2015-
limit=limit
1986+
r = session.send(
1987+
functions.upload.GetFile(
1988+
location=location,
1989+
offset=offset,
1990+
limit=limit
1991+
)
20161992
)
2017-
)
20181993

20191994
elif isinstance(r, types.upload.FileCdnRedirect):
20201995
with self.media_sessions_lock:
@@ -2028,80 +2003,78 @@ def get_file(
20282003
self.media_sessions[r.dc_id] = cdn_session
20292004

20302005
try:
2031-
if hasattr(f, "name"):
2006+
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
20322007
file_name = f.name
20332008

2034-
while True:
2035-
r2 = cdn_session.send(
2036-
functions.upload.GetCdnFile(
2037-
file_token=r.file_token,
2038-
offset=offset,
2039-
limit=limit
2009+
while True:
2010+
r2 = cdn_session.send(
2011+
functions.upload.GetCdnFile(
2012+
file_token=r.file_token,
2013+
offset=offset,
2014+
limit=limit
2015+
)
20402016
)
2041-
)
20422017

2043-
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
2044-
try:
2045-
session.send(
2046-
functions.upload.ReuploadCdnFile(
2047-
file_token=r.file_token,
2048-
request_token=r2.request_token
2018+
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
2019+
try:
2020+
session.send(
2021+
functions.upload.ReuploadCdnFile(
2022+
file_token=r.file_token,
2023+
request_token=r2.request_token
2024+
)
20492025
)
2050-
)
2051-
except VolumeLocNotFound:
2052-
break
2053-
else:
2054-
continue
2026+
except VolumeLocNotFound:
2027+
break
2028+
else:
2029+
continue
20552030

2056-
chunk = r2.bytes
2031+
chunk = r2.bytes
20572032

2058-
# https://core.telegram.org/cdn#decrypting-files
2059-
decrypted_chunk = AES.ctr256_decrypt(
2060-
chunk,
2061-
r.encryption_key,
2062-
bytearray(
2063-
r.encryption_iv[:-4]
2064-
+ (offset // 16).to_bytes(4, "big")
2033+
# https://core.telegram.org/cdn#decrypting-files
2034+
decrypted_chunk = AES.ctr256_decrypt(
2035+
chunk,
2036+
r.encryption_key,
2037+
bytearray(
2038+
r.encryption_iv[:-4]
2039+
+ (offset // 16).to_bytes(4, "big")
2040+
)
20652041
)
2066-
)
20672042

2068-
hashes = session.send(
2069-
functions.upload.GetCdnFileHashes(
2070-
file_token=r.file_token,
2071-
offset=offset
2043+
hashes = session.send(
2044+
functions.upload.GetCdnFileHashes(
2045+
file_token=r.file_token,
2046+
offset=offset
2047+
)
20722048
)
2073-
)
20742049

2075-
# https://core.telegram.org/cdn#verifying-files
2076-
for i, h in enumerate(hashes):
2077-
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
2078-
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
2050+
# https://core.telegram.org/cdn#verifying-files
2051+
for i, h in enumerate(hashes):
2052+
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
2053+
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
20792054

2080-
f.write(decrypted_chunk)
2055+
f.write(decrypted_chunk)
20812056

2082-
offset += limit
2083-
2084-
if progress:
2085-
progress(
2057+
offset += limit
20862058

2087-
min(offset, file_size)
2088-
if file_size != 0
2089-
else offset,
2090-
file_size,
2091-
*progress_args
2092-
)
2059+
if progress:
2060+
progress(
2061+
min(offset, file_size)
2062+
if file_size != 0
2063+
else offset,
2064+
file_size,
2065+
*progress_args
2066+
)
20932067

2094-
if len(chunk) < limit:
2095-
break
2068+
if len(chunk) < limit:
2069+
break
20962070
except Exception as e:
20972071
raise e
20982072
except Exception as e:
20992073
if not isinstance(e, Client.StopTransmission):
21002074
log.error(e, exc_info=True)
21012075

21022076
try:
2103-
if out:
2104-
os.remove(file_name)
2077+
os.remove(file_name)
21052078
except OSError:
21062079
pass
21072080

0 commit comments

Comments
 (0)