2929from concurrent .futures .thread import ThreadPoolExecutor
3030from hashlib import sha256
3131from importlib import import_module
32- from io import StringIO
32+ from io import StringIO , BytesIO
3333from mimetypes import MimeTypes
3434from pathlib import Path
35- from typing import Union , List , Optional , Callable
35+ from typing import Union , List , Optional , Callable , BinaryIO
3636
3737import pyrogram
3838from pyrogram import __version__ , __license__
@@ -482,34 +482,6 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
482482
483483 return is_min
484484
485- async def handle_download (self , packet ):
486- temp_file_path = ""
487- final_file_path = ""
488-
489- try :
490- file_id , directory , file_name , file_size , progress , progress_args = packet
491-
492- temp_file_path = await self .get_file (
493- file_id = file_id ,
494- file_size = file_size ,
495- progress = progress ,
496- progress_args = progress_args
497- )
498-
499- if temp_file_path :
500- final_file_path = os .path .abspath (re .sub ("\\ \\ " , "/" , os .path .join (directory , file_name )))
501- os .makedirs (directory , exist_ok = True )
502- shutil .move (temp_file_path , final_file_path )
503- except Exception as e :
504- log .error (e , exc_info = True )
505-
506- try :
507- os .remove (temp_file_path )
508- except OSError :
509- pass
510- else :
511- return final_file_path or None
512-
513485 async def handle_updates (self , updates ):
514486 if isinstance (updates , (raw .types .Updates , raw .types .UpdatesCombined )):
515487 is_min = (await self .fetch_peers (updates .users )) or (await self .fetch_peers (updates .chats ))
@@ -747,13 +719,41 @@ def load_plugins(self):
747719 else :
748720 log .warning (f'[{ self .session_name } ] No plugin loaded from "{ root } "' )
749721
722+ async def handle_download (self , packet ):
723+ file_id , directory , file_name , in_memory , file_size , progress , progress_args = packet
724+
725+ file = await self .get_file (
726+ file_id = file_id ,
727+ file_size = file_size ,
728+ in_memory = in_memory ,
729+ progress = progress ,
730+ progress_args = progress_args
731+ )
732+
733+ if file and not in_memory :
734+ file_path = os .path .abspath (re .sub ("\\ \\ " , "/" , os .path .join (directory , file_name )))
735+ os .makedirs (directory , exist_ok = True )
736+ shutil .move (file .name , file_path )
737+
738+ try :
739+ file .close ()
740+ except FileNotFoundError :
741+ pass
742+
743+ return file_path
744+
745+ if file and in_memory :
746+ file .name = file_name
747+ return file
748+
750749 async def get_file (
751750 self ,
752751 file_id : FileId ,
753752 file_size : int ,
753+ in_memory : bool ,
754754 progress : Callable ,
755755 progress_args : tuple = ()
756- ) -> str :
756+ ) -> Optional [ BinaryIO ] :
757757 dc_id = file_id .dc_id
758758
759759 async with self .media_sessions_lock :
@@ -838,7 +838,8 @@ async def get_file(
838838
839839 limit = 1024 * 1024
840840 offset = 0
841- file_name = ""
841+
842+ file = BytesIO () if in_memory else tempfile .NamedTemporaryFile ("wb" )
842843
843844 try :
844845 r = await session .invoke (
@@ -851,43 +852,40 @@ async def get_file(
851852 )
852853
853854 if isinstance (r , raw .types .upload .File ):
854- with tempfile .NamedTemporaryFile ("wb" , delete = False ) as f :
855- file_name = f .name
856-
857- while True :
858- chunk = r .bytes
859-
860- f .write (chunk )
855+ while True :
856+ chunk = r .bytes
861857
862- offset += limit
858+ file . write ( chunk )
863859
864- if progress :
865- func = functools .partial (
866- progress ,
867- min (offset , file_size )
868- if file_size != 0
869- else offset ,
870- file_size ,
871- * progress_args
872- )
873-
874- if inspect .iscoroutinefunction (progress ):
875- await func ()
876- else :
877- await self .loop .run_in_executor (self .executor , func )
860+ offset += limit
878861
879- if len (chunk ) < limit :
880- break
881-
882- r = await session .invoke (
883- raw .functions .upload .GetFile (
884- location = location ,
885- offset = offset ,
886- limit = limit
887- ),
888- sleep_threshold = 30
862+ if progress :
863+ func = functools .partial (
864+ progress ,
865+ min (offset , file_size )
866+ if file_size != 0
867+ else offset ,
868+ file_size ,
869+ * progress_args
889870 )
890871
872+ if inspect .iscoroutinefunction (progress ):
873+ await func ()
874+ else :
875+ await self .loop .run_in_executor (self .executor , func )
876+
877+ if len (chunk ) < limit :
878+ break
879+
880+ r = await session .invoke (
881+ raw .functions .upload .GetFile (
882+ location = location ,
883+ offset = offset ,
884+ limit = limit
885+ ),
886+ sleep_threshold = 30
887+ )
888+
891889 elif isinstance (r , raw .types .upload .FileCdnRedirect ):
892890 async with self .media_sessions_lock :
893891 cdn_session = self .media_sessions .get (r .dc_id , None )
@@ -903,88 +901,82 @@ async def get_file(
903901 self .media_sessions [r .dc_id ] = cdn_session
904902
905903 try :
906- with tempfile .NamedTemporaryFile ("wb" , delete = False ) as f :
907- file_name = f .name
908-
909- while True :
910- r2 = await cdn_session .invoke (
911- raw .functions .upload .GetCdnFile (
912- file_token = r .file_token ,
913- offset = offset ,
914- limit = limit
915- )
904+ while True :
905+ r2 = await cdn_session .invoke (
906+ raw .functions .upload .GetCdnFile (
907+ file_token = r .file_token ,
908+ offset = offset ,
909+ limit = limit
916910 )
911+ )
917912
918- if isinstance (r2 , raw .types .upload .CdnFileReuploadNeeded ):
919- try :
920- await session .invoke (
921- raw .functions .upload .ReuploadCdnFile (
922- file_token = r .file_token ,
923- request_token = r2 .request_token
924- )
913+ if isinstance (r2 , raw .types .upload .CdnFileReuploadNeeded ):
914+ try :
915+ await session .invoke (
916+ raw .functions .upload .ReuploadCdnFile (
917+ file_token = r .file_token ,
918+ request_token = r2 .request_token
925919 )
926- except VolumeLocNotFound :
927- break
928- else :
929- continue
930-
931- chunk = r2 .bytes
932-
933- # https://core.telegram.org/cdn#decrypting-files
934- decrypted_chunk = aes .ctr256_decrypt (
935- chunk ,
936- r .encryption_key ,
937- bytearray (
938- r .encryption_iv [:- 4 ]
939- + (offset // 16 ).to_bytes (4 , "big" )
940920 )
921+ except VolumeLocNotFound :
922+ break
923+ else :
924+ continue
925+
926+ chunk = r2 .bytes
927+
928+ # https://core.telegram.org/cdn#decrypting-files
929+ decrypted_chunk = aes .ctr256_decrypt (
930+ chunk ,
931+ r .encryption_key ,
932+ bytearray (
933+ r .encryption_iv [:- 4 ]
934+ + (offset // 16 ).to_bytes (4 , "big" )
941935 )
936+ )
942937
943- hashes = await session .invoke (
944- raw .functions .upload .GetCdnFileHashes (
945- file_token = r .file_token ,
946- offset = offset
947- )
938+ hashes = await session .invoke (
939+ raw .functions .upload .GetCdnFileHashes (
940+ file_token = r .file_token ,
941+ offset = offset
948942 )
943+ )
949944
950- # https://core.telegram.org/cdn#verifying-files
951- for i , h in enumerate (hashes ):
952- cdn_chunk = decrypted_chunk [h .limit * i : h .limit * (i + 1 )]
953- CDNFileHashMismatch .check (h .hash == sha256 (cdn_chunk ).digest ())
945+ # https://core.telegram.org/cdn#verifying-files
946+ for i , h in enumerate (hashes ):
947+ cdn_chunk = decrypted_chunk [h .limit * i : h .limit * (i + 1 )]
948+ CDNFileHashMismatch .check (h .hash == sha256 (cdn_chunk ).digest ())
954949
955- f .write (decrypted_chunk )
950+ file .write (decrypted_chunk )
956951
957- offset += limit
952+ offset += limit
958953
959- if progress :
960- func = functools .partial (
961- progress ,
962- min (offset , file_size ) if file_size != 0 else offset ,
963- file_size ,
964- * progress_args
965- )
954+ if progress :
955+ func = functools .partial (
956+ progress ,
957+ min (offset , file_size ) if file_size != 0 else offset ,
958+ file_size ,
959+ * progress_args
960+ )
966961
967- if inspect .iscoroutinefunction (progress ):
968- await func ()
969- else :
970- await self .loop .run_in_executor (self .executor , func )
962+ if inspect .iscoroutinefunction (progress ):
963+ await func ()
964+ else :
965+ await self .loop .run_in_executor (self .executor , func )
971966
972- if len (chunk ) < limit :
973- break
967+ if len (chunk ) < limit :
968+ break
974969 except Exception as e :
975970 raise e
976971 except Exception as e :
977972 if not isinstance (e , pyrogram .StopTransmission ):
978973 log .error (e , exc_info = True )
979974
980- try :
981- os .remove (file_name )
982- except OSError :
983- pass
975+ file .close ()
984976
985- return ""
977+ return None
986978 else :
987- return file_name
979+ return file
988980
989981 def guess_mime_type (self , filename : str ) -> Optional [str ]:
990982 return self .mimetypes .guess_type (filename )[0 ]
0 commit comments