Skip to content

Commit 6cc9688

Browse files
committed
Implement FileStorage and MemoryStorage engines
1 parent 6177abb commit 6cc9688

File tree

3 files changed

+377
-0
lines changed

3 files changed

+377
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Pyrogram - Telegram MTProto API Client Library for Python
2+
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
3+
#
4+
# This file is part of Pyrogram.
5+
#
6+
# Pyrogram is free software: you can redistribute it and/or modify
7+
# it under the terms of the GNU Lesser General Public License as published
8+
# by the Free Software Foundation, either version 3 of the License, or
9+
# (at your option) any later version.
10+
#
11+
# Pyrogram is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
# GNU Lesser General Public License for more details.
15+
#
16+
# You should have received a copy of the GNU Lesser General Public License
17+
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
18+
19+
import base64
20+
import json
21+
import logging
22+
import os
23+
import sqlite3
24+
from pathlib import Path
25+
from sqlite3 import DatabaseError
26+
from threading import Lock
27+
from typing import Union
28+
29+
from .memory_storage import MemoryStorage
30+
31+
log = logging.getLogger(__name__)
32+
33+
34+
class FileStorage(MemoryStorage):
35+
FILE_EXTENSION = ".session"
36+
37+
def __init__(self, name: str, workdir: Path):
38+
super().__init__(name)
39+
40+
self.workdir = workdir
41+
self.database = workdir / (self.name + self.FILE_EXTENSION)
42+
self.conn = None # type: sqlite3.Connection
43+
self.lock = Lock()
44+
45+
# noinspection PyAttributeOutsideInit
46+
def migrate_from_json(self, path: Union[str, Path]):
47+
log.warning("JSON session storage detected! Pyrogram will now convert it into an SQLite session storage...")
48+
49+
with open(path, encoding="utf-8") as f:
50+
json_session = json.load(f)
51+
52+
os.remove(path)
53+
54+
self.open()
55+
56+
self.dc_id = json_session["dc_id"]
57+
self.test_mode = json_session["test_mode"]
58+
self.auth_key = base64.b64decode("".join(json_session["auth_key"]))
59+
self.user_id = json_session["user_id"]
60+
self.date = json_session.get("date", 0)
61+
self.is_bot = json_session.get("is_bot", False)
62+
63+
peers_by_id = json_session.get("peers_by_id", {})
64+
peers_by_phone = json_session.get("peers_by_phone", {})
65+
66+
peers = {}
67+
68+
for k, v in peers_by_id.items():
69+
if v is None:
70+
type_ = "group"
71+
elif k.startswith("-100"):
72+
type_ = "channel"
73+
else:
74+
type_ = "user"
75+
76+
peers[int(k)] = [int(k), int(v) if v is not None else None, type_, None, None]
77+
78+
for k, v in peers_by_phone.items():
79+
peers[v][4] = k
80+
81+
# noinspection PyTypeChecker
82+
self.update_peers(peers.values())
83+
84+
log.warning("Done! The session has been successfully converted from JSON to SQLite storage")
85+
86+
def open(self):
87+
database_exists = os.path.isfile(self.database)
88+
89+
self.conn = sqlite3.connect(
90+
str(self.database),
91+
timeout=1,
92+
check_same_thread=False
93+
)
94+
95+
try:
96+
if not database_exists:
97+
self.create()
98+
99+
with self.conn:
100+
self.conn.execute("VACUUM")
101+
except DatabaseError:
102+
self.migrate_from_json(self.database)
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Pyrogram - Telegram MTProto API Client Library for Python
2+
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
3+
#
4+
# This file is part of Pyrogram.
5+
#
6+
# Pyrogram is free software: you can redistribute it and/or modify
7+
# it under the terms of the GNU Lesser General Public License as published
8+
# by the Free Software Foundation, either version 3 of the License, or
9+
# (at your option) any later version.
10+
#
11+
# Pyrogram is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
# GNU Lesser General Public License for more details.
15+
#
16+
# You should have received a copy of the GNU Lesser General Public License
17+
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
18+
19+
import base64
20+
import inspect
21+
import logging
22+
import sqlite3
23+
import struct
24+
import time
25+
from pathlib import Path
26+
from threading import Lock
27+
from typing import List, Tuple
28+
29+
from pyrogram.api import types
30+
from pyrogram.client.storage.storage import Storage
31+
32+
log = logging.getLogger(__name__)
33+
34+
35+
class MemoryStorage(Storage):
36+
SCHEMA_VERSION = 1
37+
USERNAME_TTL = 8 * 60 * 60
38+
SESSION_STRING_FMT = ">B?256sI?"
39+
SESSION_STRING_SIZE = 351
40+
41+
def __init__(self, name: str):
42+
super().__init__(name)
43+
44+
self.conn = None # type: sqlite3.Connection
45+
self.lock = Lock()
46+
47+
def create(self):
48+
with self.lock, self.conn:
49+
with open(Path(__file__).parent / "schema.sql", "r") as schema:
50+
self.conn.executescript(schema.read())
51+
52+
self.conn.execute(
53+
"INSERT INTO version VALUES (?)",
54+
(self.SCHEMA_VERSION,)
55+
)
56+
57+
self.conn.execute(
58+
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)",
59+
(1, None, None, 0, None, None)
60+
)
61+
62+
def _import_session_string(self, string_session: str):
63+
decoded = base64.urlsafe_b64decode(string_session + "=" * (-len(string_session) % 4))
64+
return struct.unpack(self.SESSION_STRING_FMT, decoded)
65+
66+
def export_session_string(self):
67+
packed = struct.pack(
68+
self.SESSION_STRING_FMT,
69+
self.dc_id,
70+
self.test_mode,
71+
self.auth_key,
72+
self.user_id,
73+
self.is_bot
74+
)
75+
76+
return base64.urlsafe_b64encode(packed).decode().rstrip("=")
77+
78+
# noinspection PyAttributeOutsideInit
79+
def open(self):
80+
self.conn = sqlite3.connect(":memory:", check_same_thread=False)
81+
self.create()
82+
83+
if self.name != ":memory:":
84+
imported_session_string = self._import_session_string(self.name)
85+
86+
self.dc_id, self.test_mode, self.auth_key, self.user_id, self.is_bot = imported_session_string
87+
self.date = 0
88+
89+
self.name = ":memory:" + str(self.user_id or "<unknown>")
90+
91+
# noinspection PyAttributeOutsideInit
92+
def save(self):
93+
self.date = int(time.time())
94+
95+
with self.lock:
96+
self.conn.commit()
97+
98+
def close(self):
99+
with self.lock:
100+
self.conn.close()
101+
102+
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
103+
with self.lock:
104+
self.conn.executemany(
105+
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
106+
"VALUES (?, ?, ?, ?, ?)",
107+
peers
108+
)
109+
110+
def clear_peers(self):
111+
with self.lock, self.conn:
112+
self.conn.execute(
113+
"DELETE FROM peers"
114+
)
115+
116+
@staticmethod
117+
def _get_input_peer(peer_id: int, access_hash: int, peer_type: str):
118+
if peer_type in ["user", "bot"]:
119+
return types.InputPeerUser(
120+
user_id=peer_id,
121+
access_hash=access_hash
122+
)
123+
124+
if peer_type == "group":
125+
return types.InputPeerChat(
126+
chat_id=-peer_id
127+
)
128+
129+
if peer_type in ["channel", "supergroup"]:
130+
return types.InputPeerChannel(
131+
channel_id=int(str(peer_id)[4:]),
132+
access_hash=access_hash
133+
)
134+
135+
raise ValueError("Invalid peer type")
136+
137+
def get_peer_by_id(self, peer_id: int):
138+
r = self.conn.execute(
139+
"SELECT id, access_hash, type FROM peers WHERE id = ?",
140+
(peer_id,)
141+
).fetchone()
142+
143+
if r is None:
144+
raise KeyError("ID not found")
145+
146+
return self._get_input_peer(*r)
147+
148+
def get_peer_by_username(self, username: str):
149+
r = self.conn.execute(
150+
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?",
151+
(username,)
152+
).fetchone()
153+
154+
if r is None:
155+
raise KeyError("Username not found")
156+
157+
if abs(time.time() - r[3]) > self.USERNAME_TTL:
158+
raise KeyError("Username expired")
159+
160+
return self._get_input_peer(*r[:3])
161+
162+
def get_peer_by_phone_number(self, phone_number: str):
163+
r = self.conn.execute(
164+
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
165+
(phone_number,)
166+
).fetchone()
167+
168+
if r is None:
169+
raise KeyError("Phone number not found")
170+
171+
return self._get_input_peer(*r)
172+
173+
@property
174+
def peers_count(self):
175+
return self.conn.execute(
176+
"SELECT COUNT(*) FROM peers"
177+
).fetchone()[0]
178+
179+
def _get(self):
180+
attr = inspect.stack()[1].function
181+
182+
return self.conn.execute(
183+
"SELECT {} FROM sessions".format(attr)
184+
).fetchone()[0]
185+
186+
def _set(self, value):
187+
attr = inspect.stack()[1].function
188+
189+
with self.lock, self.conn:
190+
self.conn.execute(
191+
"UPDATE sessions SET {} = ?".format(attr),
192+
(value,)
193+
)
194+
195+
@property
196+
def dc_id(self):
197+
return self._get()
198+
199+
@dc_id.setter
200+
def dc_id(self, value):
201+
self._set(value)
202+
203+
@property
204+
def test_mode(self):
205+
return self._get()
206+
207+
@test_mode.setter
208+
def test_mode(self, value):
209+
self._set(value)
210+
211+
@property
212+
def auth_key(self):
213+
return self._get()
214+
215+
@auth_key.setter
216+
def auth_key(self, value):
217+
self._set(value)
218+
219+
@property
220+
def date(self):
221+
return self._get()
222+
223+
@date.setter
224+
def date(self, value):
225+
self._set(value)
226+
227+
@property
228+
def user_id(self):
229+
return self._get()
230+
231+
@user_id.setter
232+
def user_id(self, value):
233+
self._set(value)
234+
235+
@property
236+
def is_bot(self):
237+
return self._get()
238+
239+
@is_bot.setter
240+
def is_bot(self, value):
241+
self._set(value)

pyrogram/client/storage/schema.sql

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
CREATE TABLE sessions (
2+
dc_id INTEGER PRIMARY KEY,
3+
test_mode INTEGER,
4+
auth_key BLOB,
5+
date INTEGER NOT NULL,
6+
user_id INTEGER,
7+
is_bot INTEGER
8+
);
9+
10+
CREATE TABLE peers (
11+
id INTEGER PRIMARY KEY,
12+
access_hash INTEGER,
13+
type INTEGER NOT NULL,
14+
username TEXT,
15+
phone_number TEXT,
16+
last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER))
17+
);
18+
19+
CREATE TABLE version (
20+
number INTEGER PRIMARY KEY
21+
);
22+
23+
CREATE INDEX idx_peers_id ON peers (id);
24+
CREATE INDEX idx_peers_username ON peers (username);
25+
CREATE INDEX idx_peers_phone_number ON peers (phone_number);
26+
27+
CREATE TRIGGER trg_peers_last_update_on
28+
AFTER UPDATE
29+
ON peers
30+
BEGIN
31+
UPDATE peers
32+
SET last_update_on = CAST(STRFTIME('%s', 'now') AS INTEGER)
33+
WHERE id = NEW.id;
34+
END;

0 commit comments

Comments
 (0)