3232from io import StringIO , BytesIO
3333from mimetypes import MimeTypes
3434from pathlib import Path
35- from typing import Union , List , Optional , Callable , BinaryIO
35+ from typing import Union , List , Optional , Callable , AsyncGenerator
3636
3737import pyrogram
3838from pyrogram import __version__ , __license__
@@ -722,13 +722,10 @@ def load_plugins(self):
722722 async def handle_download (self , packet ):
723723 file_id , directory , file_name , in_memory , file_size , progress , progress_args = packet
724724
725- file = await self .get_file (
726- file_id = file_id ,
727- file_size = file_size ,
728- in_memory = in_memory ,
729- progress = progress ,
730- progress_args = progress_args
731- )
725+ file = BytesIO () if in_memory else tempfile .NamedTemporaryFile ("wb" , delete = False )
726+
727+ async for chunk in self .get_file (file_id , file_size , 0 , 0 , progress , progress_args ):
728+ file .write (chunk )
732729
733730 if file and not in_memory :
734731 file_path = os .path .abspath (re .sub ("\\ \\ " , "/" , os .path .join (directory , file_name )))
@@ -749,11 +746,12 @@ async def handle_download(self, packet):
749746 async def get_file (
750747 self ,
751748 file_id : FileId ,
752- file_size : int ,
753- in_memory : bool ,
754- progress : Callable ,
749+ file_size : int = 0 ,
750+ limit : int = 0 ,
751+ offset : int = 0 ,
752+ progress : Callable = None ,
755753 progress_args : tuple = ()
756- ) -> Optional [BinaryIO ]:
754+ ) -> Optional [AsyncGenerator [ bytes , None ] ]:
757755 dc_id = file_id .dc_id
758756
759757 async with self .media_sessions_lock :
@@ -836,17 +834,17 @@ async def get_file(
836834 thumb_size = file_id .thumbnail_size
837835 )
838836
839- limit = 1024 * 1024
840- offset = 0
841-
842- file = BytesIO () if in_memory else tempfile . NamedTemporaryFile ( "wb" )
837+ current = 0
838+ total = abs ( limit ) or ( 1 << 31 ) - 1
839+ chunk_size = 1024 * 1024
840+ offset_bytes = abs ( offset ) * chunk_size
843841
844842 try :
845843 r = await session .invoke (
846844 raw .functions .upload .GetFile (
847845 location = location ,
848- offset = offset ,
849- limit = limit
846+ offset = offset_bytes ,
847+ limit = chunk_size
850848 ),
851849 sleep_threshold = 30
852850 )
@@ -855,16 +853,17 @@ async def get_file(
855853 while True :
856854 chunk = r .bytes
857855
858- file . write ( chunk )
856+ yield chunk
859857
860- offset += limit
858+ current += 1
859+ offset_bytes += chunk_size
861860
862861 if progress :
863862 func = functools .partial (
864863 progress ,
865- min (offset , file_size )
864+ min (offset_bytes , file_size )
866865 if file_size != 0
867- else offset ,
866+ else offset_bytes ,
868867 file_size ,
869868 * progress_args
870869 )
@@ -874,14 +873,14 @@ async def get_file(
874873 else :
875874 await self .loop .run_in_executor (self .executor , func )
876875
877- if len (chunk ) < limit :
876+ if len (chunk ) < chunk_size or current >= total :
878877 break
879878
880879 r = await session .invoke (
881880 raw .functions .upload .GetFile (
882881 location = location ,
883- offset = offset ,
884- limit = limit
882+ offset = offset_bytes ,
883+ limit = chunk_size
885884 ),
886885 sleep_threshold = 30
887886 )
@@ -905,8 +904,8 @@ async def get_file(
905904 r2 = await cdn_session .invoke (
906905 raw .functions .upload .GetCdnFile (
907906 file_token = r .file_token ,
908- offset = offset ,
909- limit = limit
907+ offset = offset_bytes ,
908+ limit = chunk_size
910909 )
911910 )
912911
@@ -931,14 +930,14 @@ async def get_file(
931930 r .encryption_key ,
932931 bytearray (
933932 r .encryption_iv [:- 4 ]
934- + (offset // 16 ).to_bytes (4 , "big" )
933+ + (offset_bytes // 16 ).to_bytes (4 , "big" )
935934 )
936935 )
937936
938937 hashes = await session .invoke (
939938 raw .functions .upload .GetCdnFileHashes (
940939 file_token = r .file_token ,
941- offset = offset
940+ offset = offset_bytes
942941 )
943942 )
944943
@@ -947,14 +946,15 @@ async def get_file(
947946 cdn_chunk = decrypted_chunk [h .limit * i : h .limit * (i + 1 )]
948947 CDNFileHashMismatch .check (h .hash == sha256 (cdn_chunk ).digest ())
949948
950- file . write ( decrypted_chunk )
949+ yield decrypted_chunk
951950
952- offset += limit
951+ current += 1
952+ offset_bytes += chunk_size
953953
954954 if progress :
955955 func = functools .partial (
956956 progress ,
957- min (offset , file_size ) if file_size != 0 else offset ,
957+ min (offset_bytes , file_size ) if file_size != 0 else offset_bytes ,
958958 file_size ,
959959 * progress_args
960960 )
@@ -964,20 +964,14 @@ async def get_file(
964964 else :
965965 await self .loop .run_in_executor (self .executor , func )
966966
967- if len (chunk ) < limit :
967+ if len (chunk ) < chunk_size or current >= total :
968968 break
969969 except Exception as e :
970970 raise e
971971 except Exception as e :
972972 if not isinstance (e , pyrogram .StopTransmission ):
973973 log .error (e , exc_info = True )
974974
975- file .close ()
976-
977- return None
978- else :
979- return file
980-
981975 def guess_mime_type (self , filename : str ) -> Optional [str ]:
982976 return self .mimetypes .guess_type (filename )[0 ]
983977
0 commit comments