Skip to content

Commit 3e33ef0

Browse files
committed
Add support for media streams with the method stream_media
1 parent b2c4d26 commit 3e33ef0

File tree

4 files changed

+130
-40
lines changed

4 files changed

+130
-40
lines changed

compiler/docs/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def get_title_list(s: str) -> list:
187187
search_global
188188
search_global_count
189189
download_media
190+
stream_media
190191
get_discussion_message
191192
get_discussion_replies
192193
get_discussion_replies_count

pyrogram/client.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from io import StringIO, BytesIO
3333
from mimetypes import MimeTypes
3434
from pathlib import Path
35-
from typing import Union, List, Optional, Callable, BinaryIO
35+
from typing import Union, List, Optional, Callable, AsyncGenerator
3636

3737
import pyrogram
3838
from pyrogram import __version__, __license__
@@ -722,13 +722,10 @@ def load_plugins(self):
722722
async def handle_download(self, packet):
723723
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
724724

725-
file = await self.get_file(
726-
file_id=file_id,
727-
file_size=file_size,
728-
in_memory=in_memory,
729-
progress=progress,
730-
progress_args=progress_args
731-
)
725+
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False)
726+
727+
async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args):
728+
file.write(chunk)
732729

733730
if file and not in_memory:
734731
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
@@ -749,11 +746,12 @@ async def handle_download(self, packet):
749746
async def get_file(
750747
self,
751748
file_id: FileId,
752-
file_size: int,
753-
in_memory: bool,
754-
progress: Callable,
749+
file_size: int = 0,
750+
limit: int = 0,
751+
offset: int = 0,
752+
progress: Callable = None,
755753
progress_args: tuple = ()
756-
) -> Optional[BinaryIO]:
754+
) -> Optional[AsyncGenerator[bytes, None]]:
757755
dc_id = file_id.dc_id
758756

759757
async with self.media_sessions_lock:
@@ -836,17 +834,17 @@ async def get_file(
836834
thumb_size=file_id.thumbnail_size
837835
)
838836

839-
limit = 1024 * 1024
840-
offset = 0
841-
842-
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb")
837+
current = 0
838+
total = abs(limit) or (1 << 31) - 1
839+
chunk_size = 1024 * 1024
840+
offset_bytes = abs(offset) * chunk_size
843841

844842
try:
845843
r = await session.invoke(
846844
raw.functions.upload.GetFile(
847845
location=location,
848-
offset=offset,
849-
limit=limit
846+
offset=offset_bytes,
847+
limit=chunk_size
850848
),
851849
sleep_threshold=30
852850
)
@@ -855,16 +853,17 @@ async def get_file(
855853
while True:
856854
chunk = r.bytes
857855

858-
file.write(chunk)
856+
yield chunk
859857

860-
offset += limit
858+
current += 1
859+
offset_bytes += chunk_size
861860

862861
if progress:
863862
func = functools.partial(
864863
progress,
865-
min(offset, file_size)
864+
min(offset_bytes, file_size)
866865
if file_size != 0
867-
else offset,
866+
else offset_bytes,
868867
file_size,
869868
*progress_args
870869
)
@@ -874,14 +873,14 @@ async def get_file(
874873
else:
875874
await self.loop.run_in_executor(self.executor, func)
876875

877-
if len(chunk) < limit:
876+
if len(chunk) < chunk_size or current >= total:
878877
break
879878

880879
r = await session.invoke(
881880
raw.functions.upload.GetFile(
882881
location=location,
883-
offset=offset,
884-
limit=limit
882+
offset=offset_bytes,
883+
limit=chunk_size
885884
),
886885
sleep_threshold=30
887886
)
@@ -905,8 +904,8 @@ async def get_file(
905904
r2 = await cdn_session.invoke(
906905
raw.functions.upload.GetCdnFile(
907906
file_token=r.file_token,
908-
offset=offset,
909-
limit=limit
907+
offset=offset_bytes,
908+
limit=chunk_size
910909
)
911910
)
912911

@@ -931,14 +930,14 @@ async def get_file(
931930
r.encryption_key,
932931
bytearray(
933932
r.encryption_iv[:-4]
934-
+ (offset // 16).to_bytes(4, "big")
933+
+ (offset_bytes // 16).to_bytes(4, "big")
935934
)
936935
)
937936

938937
hashes = await session.invoke(
939938
raw.functions.upload.GetCdnFileHashes(
940939
file_token=r.file_token,
941-
offset=offset
940+
offset=offset_bytes
942941
)
943942
)
944943

@@ -947,14 +946,15 @@ async def get_file(
947946
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
948947
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
949948

950-
file.write(decrypted_chunk)
949+
yield decrypted_chunk
951950

952-
offset += limit
951+
current += 1
952+
offset_bytes += chunk_size
953953

954954
if progress:
955955
func = functools.partial(
956956
progress,
957-
min(offset, file_size) if file_size != 0 else offset,
957+
min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
958958
file_size,
959959
*progress_args
960960
)
@@ -964,20 +964,14 @@ async def get_file(
964964
else:
965965
await self.loop.run_in_executor(self.executor, func)
966966

967-
if len(chunk) < limit:
967+
if len(chunk) < chunk_size or current >= total:
968968
break
969969
except Exception as e:
970970
raise e
971971
except Exception as e:
972972
if not isinstance(e, pyrogram.StopTransmission):
973973
log.error(e, exc_info=True)
974974

975-
file.close()
976-
977-
return None
978-
else:
979-
return file
980-
981975
def guess_mime_type(self, filename: str) -> Optional[str]:
982976
return self.mimetypes.guess_type(filename)[0]
983977

pyrogram/methods/messages/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .send_video_note import SendVideoNote
6262
from .send_voice import SendVoice
6363
from .stop_poll import StopPoll
64+
from .stream_media import StreamMedia
6465
from .vote_poll import VotePoll
6566

6667

@@ -110,6 +111,7 @@ class Messages(
110111
GetDiscussionMessage,
111112
SendReaction,
112113
GetDiscussionReplies,
113-
GetDiscussionRepliesCount
114+
GetDiscussionRepliesCount,
115+
StreamMedia
114116
):
115117
pass
pyrogram/methods/messages/stream_media.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Pyrogram - Telegram MTProto API Client Library for Python
2+
# Copyright (C) 2017-present Dan <https://github.com/delivrance>
3+
#
4+
# This file is part of Pyrogram.
5+
#
6+
# Pyrogram is free software: you can redistribute it and/or modify
7+
# it under the terms of the GNU Lesser General Public License as published
8+
# by the Free Software Foundation, either version 3 of the License, or
9+
# (at your option) any later version.
10+
#
11+
# Pyrogram is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
# GNU Lesser General Public License for more details.
15+
#
16+
# You should have received a copy of the GNU Lesser General Public License
17+
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
18+
19+
from typing import AsyncGenerator, BinaryIO, Optional, Union
20+
21+
import pyrogram
22+
from pyrogram import types
23+
from pyrogram.file_id import FileId
24+
25+
26+
class StreamMedia:
    async def stream_media(
        self: "pyrogram.Client",
        message: Union["types.Message", str],
        limit: int = 0,
        offset: int = 0
    ) -> AsyncGenerator[bytes, None]:
        """Stream the media from a message chunk by chunk.

        This method is an async generator: iterate it with ``async for`` instead of awaiting it.

        The chunk size is 1 MiB (1024 * 1024 bytes).

        Parameters:
            message (:obj:`~pyrogram.types.Message` | ``str``):
                Pass a Message containing the media, the media itself (message.audio, message.video, ...) or a file id
                as string.

            limit (``int``, *optional*):
                Limit the amount of chunks to stream.
                Defaults to 0 (stream the whole media).

            offset (``int``, *optional*):
                How many chunks to skip before starting to stream.
                Defaults to 0 (start from the beginning).

        Returns:
            ``AsyncGenerator``: An async generator yielding bytes chunk by chunk.

        Raises:
            ValueError: In case a Message is passed and it doesn't contain any downloadable media.

        Example:
            .. code-block:: python

                # Stream the whole media
                async for chunk in app.stream_media(message):
                    print(len(chunk))

                # Stream the first 3 chunks only
                async for chunk in app.stream_media(message, limit=3):
                    print(len(chunk))

                # Stream the last 3 chunks only
                import math
                chunks = math.ceil(message.document.file_size / 1024 / 1024)
                async for chunk in app.stream_media(message, offset=chunks - 3):
                    print(len(chunk))
        """
        available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
                           "new_chat_photo")

        if isinstance(message, types.Message):
            # Pick the first media attribute that is set on the message, in the order listed above.
            for kind in available_media:
                media = getattr(message, kind, None)

                if media is not None:
                    break
            else:
                # The loop finished without a break: none of the media attributes is set.
                raise ValueError("This message doesn't contain any downloadable media")
        else:
            # Not a Message: either a media object exposing "file_id" or a file id string.
            media = message

        if isinstance(media, str):
            file_id_str = media
        else:
            file_id_str = media.file_id

        file_id_obj = FileId.decode(file_id_str)
        # A plain file id string has no "file_size" attribute, so 0 is passed on in that case.
        file_size = getattr(media, "file_size", 0)

        async for chunk in self.get_file(file_id_obj, file_size, limit, offset):
            yield chunk

0 commit comments

Comments
 (0)