Skip to content

Commit 01ca652

Browse files
committed
Add support for in-memory downloads
1 parent 0d054fa commit 01ca652

3 files changed

Lines changed: 152 additions & 134 deletions

File tree

pyrogram/client.py

Lines changed: 117 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
from concurrent.futures.thread import ThreadPoolExecutor
3030
from hashlib import sha256
3131
from importlib import import_module
32-
from io import StringIO
32+
from io import StringIO, BytesIO
3333
from mimetypes import MimeTypes
3434
from pathlib import Path
35-
from typing import Union, List, Optional, Callable
35+
from typing import Union, List, Optional, Callable, BinaryIO
3636

3737
import pyrogram
3838
from pyrogram import __version__, __license__
@@ -482,34 +482,6 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
482482

483483
return is_min
484484

485-
async def handle_download(self, packet):
486-
temp_file_path = ""
487-
final_file_path = ""
488-
489-
try:
490-
file_id, directory, file_name, file_size, progress, progress_args = packet
491-
492-
temp_file_path = await self.get_file(
493-
file_id=file_id,
494-
file_size=file_size,
495-
progress=progress,
496-
progress_args=progress_args
497-
)
498-
499-
if temp_file_path:
500-
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
501-
os.makedirs(directory, exist_ok=True)
502-
shutil.move(temp_file_path, final_file_path)
503-
except Exception as e:
504-
log.error(e, exc_info=True)
505-
506-
try:
507-
os.remove(temp_file_path)
508-
except OSError:
509-
pass
510-
else:
511-
return final_file_path or None
512-
513485
async def handle_updates(self, updates):
514486
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
515487
is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats))
@@ -747,13 +719,41 @@ def load_plugins(self):
747719
else:
748720
log.warning(f'[{self.session_name}] No plugin loaded from "{root}"')
749721

722+
async def handle_download(self, packet):
723+
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
724+
725+
file = await self.get_file(
726+
file_id=file_id,
727+
file_size=file_size,
728+
in_memory=in_memory,
729+
progress=progress,
730+
progress_args=progress_args
731+
)
732+
733+
if file and not in_memory:
734+
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
735+
os.makedirs(directory, exist_ok=True)
736+
shutil.move(file.name, file_path)
737+
738+
try:
739+
file.close()
740+
except FileNotFoundError:
741+
pass
742+
743+
return file_path
744+
745+
if file and in_memory:
746+
file.name = file_name
747+
return file
748+
750749
async def get_file(
751750
self,
752751
file_id: FileId,
753752
file_size: int,
753+
in_memory: bool,
754754
progress: Callable,
755755
progress_args: tuple = ()
756-
) -> str:
756+
) -> Optional[BinaryIO]:
757757
dc_id = file_id.dc_id
758758

759759
async with self.media_sessions_lock:
@@ -838,7 +838,8 @@ async def get_file(
838838

839839
limit = 1024 * 1024
840840
offset = 0
841-
file_name = ""
841+
842+
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
842843

843844
try:
844845
r = await session.invoke(
@@ -851,43 +852,40 @@ async def get_file(
851852
)
852853

853854
if isinstance(r, raw.types.upload.File):
854-
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
855-
file_name = f.name
856-
857-
while True:
858-
chunk = r.bytes
859-
860-
f.write(chunk)
855+
while True:
856+
chunk = r.bytes
861857

862-
offset += limit
858+
file.write(chunk)
863859

864-
if progress:
865-
func = functools.partial(
866-
progress,
867-
min(offset, file_size)
868-
if file_size != 0
869-
else offset,
870-
file_size,
871-
*progress_args
872-
)
873-
874-
if inspect.iscoroutinefunction(progress):
875-
await func()
876-
else:
877-
await self.loop.run_in_executor(self.executor, func)
860+
offset += limit
878861

879-
if len(chunk) < limit:
880-
break
881-
882-
r = await session.invoke(
883-
raw.functions.upload.GetFile(
884-
location=location,
885-
offset=offset,
886-
limit=limit
887-
),
888-
sleep_threshold=30
862+
if progress:
863+
func = functools.partial(
864+
progress,
865+
min(offset, file_size)
866+
if file_size != 0
867+
else offset,
868+
file_size,
869+
*progress_args
889870
)
890871

872+
if inspect.iscoroutinefunction(progress):
873+
await func()
874+
else:
875+
await self.loop.run_in_executor(self.executor, func)
876+
877+
if len(chunk) < limit:
878+
break
879+
880+
r = await session.invoke(
881+
raw.functions.upload.GetFile(
882+
location=location,
883+
offset=offset,
884+
limit=limit
885+
),
886+
sleep_threshold=30
887+
)
888+
891889
elif isinstance(r, raw.types.upload.FileCdnRedirect):
892890
async with self.media_sessions_lock:
893891
cdn_session = self.media_sessions.get(r.dc_id, None)
@@ -903,88 +901,82 @@ async def get_file(
903901
self.media_sessions[r.dc_id] = cdn_session
904902

905903
try:
906-
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
907-
file_name = f.name
908-
909-
while True:
910-
r2 = await cdn_session.invoke(
911-
raw.functions.upload.GetCdnFile(
912-
file_token=r.file_token,
913-
offset=offset,
914-
limit=limit
915-
)
904+
while True:
905+
r2 = await cdn_session.invoke(
906+
raw.functions.upload.GetCdnFile(
907+
file_token=r.file_token,
908+
offset=offset,
909+
limit=limit
916910
)
911+
)
917912

918-
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
919-
try:
920-
await session.invoke(
921-
raw.functions.upload.ReuploadCdnFile(
922-
file_token=r.file_token,
923-
request_token=r2.request_token
924-
)
913+
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
914+
try:
915+
await session.invoke(
916+
raw.functions.upload.ReuploadCdnFile(
917+
file_token=r.file_token,
918+
request_token=r2.request_token
925919
)
926-
except VolumeLocNotFound:
927-
break
928-
else:
929-
continue
930-
931-
chunk = r2.bytes
932-
933-
# https://core.telegram.org/cdn#decrypting-files
934-
decrypted_chunk = aes.ctr256_decrypt(
935-
chunk,
936-
r.encryption_key,
937-
bytearray(
938-
r.encryption_iv[:-4]
939-
+ (offset // 16).to_bytes(4, "big")
940920
)
921+
except VolumeLocNotFound:
922+
break
923+
else:
924+
continue
925+
926+
chunk = r2.bytes
927+
928+
# https://core.telegram.org/cdn#decrypting-files
929+
decrypted_chunk = aes.ctr256_decrypt(
930+
chunk,
931+
r.encryption_key,
932+
bytearray(
933+
r.encryption_iv[:-4]
934+
+ (offset // 16).to_bytes(4, "big")
941935
)
936+
)
942937

943-
hashes = await session.invoke(
944-
raw.functions.upload.GetCdnFileHashes(
945-
file_token=r.file_token,
946-
offset=offset
947-
)
938+
hashes = await session.invoke(
939+
raw.functions.upload.GetCdnFileHashes(
940+
file_token=r.file_token,
941+
offset=offset
948942
)
943+
)
949944

950-
# https://core.telegram.org/cdn#verifying-files
951-
for i, h in enumerate(hashes):
952-
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
953-
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
945+
# https://core.telegram.org/cdn#verifying-files
946+
for i, h in enumerate(hashes):
947+
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
948+
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
954949

955-
f.write(decrypted_chunk)
950+
file.write(decrypted_chunk)
956951

957-
offset += limit
952+
offset += limit
958953

959-
if progress:
960-
func = functools.partial(
961-
progress,
962-
min(offset, file_size) if file_size != 0 else offset,
963-
file_size,
964-
*progress_args
965-
)
954+
if progress:
955+
func = functools.partial(
956+
progress,
957+
min(offset, file_size) if file_size != 0 else offset,
958+
file_size,
959+
*progress_args
960+
)
966961

967-
if inspect.iscoroutinefunction(progress):
968-
await func()
969-
else:
970-
await self.loop.run_in_executor(self.executor, func)
962+
if inspect.iscoroutinefunction(progress):
963+
await func()
964+
else:
965+
await self.loop.run_in_executor(self.executor, func)
971966

972-
if len(chunk) < limit:
973-
break
967+
if len(chunk) < limit:
968+
break
974969
except Exception as e:
975970
raise e
976971
except Exception as e:
977972
if not isinstance(e, pyrogram.StopTransmission):
978973
log.error(e, exc_info=True)
979974

980-
try:
981-
os.remove(file_name)
982-
except OSError:
983-
pass
975+
file.close()
984976

985-
return ""
977+
return None
986978
else:
987-
return file_name
979+
return file
988980

989981
def guess_mime_type(self, filename: str) -> Optional[str]:
990982
return self.mimetypes.guess_type(filename)[0]

0 commit comments

Comments
 (0)