Skip to content

Commit cd027b8

Browse files
committed
Implement missing MTProto checks
1 parent bf9e186 commit cd027b8

File tree

4 files changed

+71
-4
lines changed

4 files changed

+71
-4
lines changed

pyrogram/crypto/mtproto.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
from hashlib import sha256
2020
from io import BytesIO
2121
from os import urandom
22+
from typing import Optional, List
2223

2324
from pyrogram.raw.core import Message, Long
2425
from . import aes
26+
from ..session.internals import MsgId
2527

2628

2729
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@@ -49,13 +51,19 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k
4951
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
5052

5153

52-
def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message:
54+
def unpack(
55+
b: BytesIO,
56+
session_id: bytes,
57+
auth_key: bytes,
58+
auth_key_id: bytes,
59+
stored_msg_ids: List[int]
60+
) -> Optional[Message]:
5361
assert b.read(8) == auth_key_id, b.getvalue()
5462

5563
msg_key = b.read(16)
5664
aes_key, aes_iv = kdf(auth_key, msg_key, False)
5765
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
58-
data.read(8)
66+
data.read(8) # Salt
5967

6068
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
6169
assert data.read(8) == session_id
@@ -75,11 +83,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -
7583
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
7684

7785
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
78-
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
7986
# 96 = 88 + 8 (incoming message)
8087
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
8188

89+
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
90+
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
91+
payload = data.read()
92+
padding = payload[message.length:]
93+
assert 12 <= len(padding) <= 1024
94+
assert len(payload) % 4 == 0
95+
8296
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
8397
assert message.msg_id % 2 != 0
8498

99+
if len(stored_msg_ids) > 200:
100+
stored_msg_ids = stored_msg_ids[50:]
101+
102+
if stored_msg_ids:
103+
# Ignored message: msg_id is lower than all of the stored values
104+
if message.msg_id < stored_msg_ids[0]:
105+
return None
106+
107+
# Ignored message: msg_id is equal to any of the stored values
108+
if message.msg_id in stored_msg_ids:
109+
return None
110+
111+
time_diff = (message.msg_id - MsgId()) / 2 ** 32
112+
113+
# Ignored message: msg_id belongs over 30 seconds in the future
114+
if time_diff > 30:
115+
return None
116+
117+
# Ignored message: msg_id belongs over 300 seconds in the past
118+
if time_diff < -300:
119+
return None
120+
121+
stored_msg_ids.append(message.msg_id)
122+
85123
return message

pyrogram/raw/core/future_salt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,12 @@ def read(data: BytesIO, *args: Any) -> "FutureSalt":
4242
salt = Long.read(data)
4343

4444
return FutureSalt(valid_since, valid_until, salt)
45+
46+
def write(self, *args: Any) -> bytes:
47+
b = BytesIO()
48+
49+
b.write(Int(self.valid_since))
50+
b.write(Int(self.valid_until))
51+
b.write(Long(self.salt))
52+
53+
return b.getvalue()

pyrogram/raw/core/future_salts.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,17 @@ def read(data: BytesIO, *args: Any) -> "FutureSalts":
4545
salts = [FutureSalt.read(data) for _ in range(count)]
4646

4747
return FutureSalts(req_msg_id, now, salts)
48+
49+
def write(self, *args: Any) -> bytes:
50+
b = BytesIO()
51+
52+
b.write(Long(self.req_msg_id))
53+
b.write(Int(self.now))
54+
55+
count = len(self.salts)
56+
b.write(Int(count))
57+
58+
for salt in self.salts:
59+
b.write(salt.write())
60+
61+
return b.getvalue()

pyrogram/session/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def __init__(
102102

103103
self.results = {}
104104

105+
self.stored_msg_ids = []
106+
105107
self.ping_task = None
106108
self.ping_task_event = asyncio.Event()
107109

@@ -224,9 +226,13 @@ async def handle_packet(self, packet):
224226
BytesIO(packet),
225227
self.session_id,
226228
self.auth_key,
227-
self.auth_key_id
229+
self.auth_key_id,
230+
self.stored_msg_ids
228231
)
229232

233+
if data is None:
234+
return
235+
230236
messages = (
231237
data.body.messages
232238
if isinstance(data.body, MsgContainer)

0 commit comments

Comments
 (0)