Skip to content

Commit 55599e3

Browse files
committed
Rework download_media to accommodate L100 changes
1 parent 3208b22 commit 55599e3

File tree

3 files changed

+165
-126
lines changed

3 files changed

+165
-126
lines changed

pyrogram/client/client.py

Lines changed: 78 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import base64
20-
import binascii
2120
import json
2221
import logging
2322
import math
2423
import mimetypes
2524
import os
2625
import re
2726
import shutil
28-
import struct
2927
import tempfile
3028
import threading
3129
import time
@@ -49,7 +47,7 @@
4947
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
5048
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
5149
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
52-
VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied,
50+
VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
5351
PasswordRecoveryNa, PasswordEmpty
5452
)
5553
from pyrogram.session import Auth, Session
@@ -829,85 +827,59 @@ def download_worker(self):
829827
log.debug("{} started".format(name))
830828

831829
while True:
832-
media = self.download_queue.get()
830+
packet = self.download_queue.get()
833831

834-
if media is None:
832+
if packet is None:
835833
break
836834

837835
temp_file_path = ""
838836
final_file_path = ""
839837

840838
try:
841-
media, file_name, done, progress, progress_args, path = media
842-
843-
file_id = media.file_id
844-
size = media.file_size
839+
data, file_name, done, progress, progress_args, path = packet
840+
data = data # type: BaseClient.FileData
845841

846842
directory, file_name = os.path.split(file_name)
847843
directory = directory or "downloads"
848844

849-
try:
850-
decoded = utils.decode(file_id)
851-
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
852-
unpacked = struct.unpack(fmt, decoded)
853-
except (AssertionError, binascii.Error, struct.error):
854-
raise FileIdInvalid from None
855-
else:
856-
media_type = unpacked[0]
857-
dc_id = unpacked[1]
858-
id = unpacked[2]
859-
access_hash = unpacked[3]
860-
volume_id = None
861-
secret = None
862-
local_id = None
863-
864-
if len(decoded) > 24:
865-
volume_id = unpacked[4]
866-
secret = unpacked[5]
867-
local_id = unpacked[6]
868-
869-
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
845+
media_type_str = Client.MEDIA_TYPE_ID[data.media_type]
870846

871-
if media_type_str is None:
872-
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
847+
if not data.file_name:
848+
guessed_extension = self.guess_extension(data.mime_type)
873849

874-
file_name = file_name or getattr(media, "file_name", None)
875-
876-
if not file_name:
877-
guessed_extension = self.guess_extension(media.mime_type)
878-
879-
if media_type in (0, 1, 2):
850+
if data.media_type in (0, 1, 2, 14):
880851
extension = ".jpg"
881-
elif media_type == 3:
852+
elif data.media_type == 3:
882853
extension = guessed_extension or ".ogg"
883-
elif media_type in (4, 10, 13):
854+
elif data.media_type in (4, 10, 13):
884855
extension = guessed_extension or ".mp4"
885-
elif media_type == 5:
856+
elif data.media_type == 5:
886857
extension = guessed_extension or ".zip"
887-
elif media_type == 8:
858+
elif data.media_type == 8:
888859
extension = guessed_extension or ".webp"
889-
elif media_type == 9:
860+
elif data.media_type == 9:
890861
extension = guessed_extension or ".mp3"
891862
else:
892863
continue
893864

894865
file_name = "{}_{}_{}{}".format(
895866
media_type_str,
896-
datetime.fromtimestamp(
897-
getattr(media, "date", None) or time.time()
898-
).strftime("%Y-%m-%d_%H-%M-%S"),
867+
datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
899868
self.rnd_id(),
900869
extension
901870
)
902871

903872
temp_file_path = self.get_file(
904-
dc_id=dc_id,
905-
id=id,
906-
access_hash=access_hash,
907-
volume_id=volume_id,
908-
local_id=local_id,
909-
secret=secret,
910-
size=size,
873+
media_type=data.media_type,
874+
dc_id=data.dc_id,
875+
file_id=data.file_id,
876+
access_hash=data.access_hash,
877+
thumb_size=data.thumb_size,
878+
peer_id=data.peer_id,
879+
volume_id=data.volume_id,
880+
local_id=data.local_id,
881+
file_size=data.file_size,
882+
is_big=data.is_big,
911883
progress=progress,
912884
progress_args=progress_args
913885
)
@@ -1549,16 +1521,21 @@ def save_file(
15491521
finally:
15501522
session.stop()
15511523

1552-
def get_file(self,
1553-
dc_id: int,
1554-
id: int = None,
1555-
access_hash: int = None,
1556-
volume_id: int = None,
1557-
local_id: int = None,
1558-
secret: int = None,
1559-
size: int = None,
1560-
progress: callable = None,
1561-
progress_args: tuple = ()) -> str:
1524+
def get_file(
1525+
self,
1526+
media_type: int,
1527+
dc_id: int,
1528+
file_id: int,
1529+
access_hash: int,
1530+
thumb_size: str,
1531+
peer_id: int,
1532+
volume_id: int,
1533+
local_id: int,
1534+
file_size: int,
1535+
is_big: bool,
1536+
progress: callable,
1537+
progress_args: tuple = ()
1538+
) -> str:
15621539
with self.media_sessions_lock:
15631540
session = self.media_sessions.get(dc_id, None)
15641541

@@ -1599,18 +1576,33 @@ def get_file(self,
15991576

16001577
self.media_sessions[dc_id] = session
16011578

1602-
if volume_id: # Photos are accessed by volume_id, local_id, secret
1603-
location = types.InputFileLocation(
1579+
if media_type == 1:
1580+
location = types.InputPeerPhotoFileLocation(
1581+
peer=self.resolve_peer(peer_id),
16041582
volume_id=volume_id,
16051583
local_id=local_id,
1606-
secret=secret,
1607-
file_reference=b""
1584+
big=is_big or None
16081585
)
1609-
else: # Any other file can be more easily accessed by id and access_hash
1586+
elif media_type in (0, 2):
1587+
location = types.InputPhotoFileLocation(
1588+
id=file_id,
1589+
access_hash=access_hash,
1590+
file_reference=b"",
1591+
thumb_size=thumb_size
1592+
)
1593+
elif media_type == 14:
1594+
location = types.InputDocumentFileLocation(
1595+
id=file_id,
1596+
access_hash=access_hash,
1597+
file_reference=b"",
1598+
thumb_size=thumb_size
1599+
)
1600+
else:
16101601
location = types.InputDocumentFileLocation(
1611-
id=id,
1602+
id=file_id,
16121603
access_hash=access_hash,
1613-
file_reference=b""
1604+
file_reference=b"",
1605+
thumb_size=""
16141606
)
16151607

16161608
limit = 1024 * 1024
@@ -1641,7 +1633,14 @@ def get_file(self,
16411633
offset += limit
16421634

16431635
if progress:
1644-
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
1636+
progress(
1637+
self,
1638+
min(offset, file_size)
1639+
if file_size != 0
1640+
else offset,
1641+
file_size,
1642+
*progress_args
1643+
)
16451644

16461645
r = session.send(
16471646
functions.upload.GetFile(
@@ -1723,7 +1722,14 @@ def get_file(self,
17231722
offset += limit
17241723

17251724
if progress:
1726-
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
1725+
progress(
1726+
self,
1727+
min(offset, file_size)
1728+
if file_size != 0
1729+
else offset,
1730+
file_size,
1731+
*progress_args
1732+
)
17271733

17281734
if len(chunk) < limit:
17291735
break

pyrogram/client/ext/base_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import platform
2121
import re
22+
from collections import namedtuple
2223
from queue import Queue
2324
from threading import Lock
2425

@@ -56,7 +57,7 @@ class StopTransmission(StopIteration):
5657
CONFIG_FILE = "./config.ini"
5758

5859
MEDIA_TYPE_ID = {
59-
0: "thumbnail",
60+
0: "photo_thumbnail",
6061
1: "chat_photo",
6162
2: "photo",
6263
3: "voice",
@@ -65,7 +66,8 @@ class StopTransmission(StopIteration):
6566
8: "sticker",
6667
9: "audio",
6768
10: "animation",
68-
13: "video_note"
69+
13: "video_note",
70+
14: "document_thumbnail"
6971
}
7072

7173
mime_types_to_extensions = {}
@@ -82,6 +84,10 @@ class StopTransmission(StopIteration):
8284

8385
mime_types_to_extensions[mime_type] = " ".join(extensions)
8486

87+
fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id",
88+
"is_big", "file_size", "mime_type", "file_name", "date")
89+
FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields))
90+
8591
def __init__(self):
8692
self.is_bot = None
8793
self.dc_id = None

0 commit comments

Comments
 (0)