|
17 | 17 | # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. |
18 | 18 |
|
19 | 19 | import base64 |
20 | | -import inspect |
21 | 20 | import logging |
22 | 21 | import sqlite3 |
23 | 22 | import struct |
24 | | -import time |
25 | | -from pathlib import Path |
26 | | -from threading import Lock |
27 | | -from typing import List, Tuple |
28 | 23 |
|
29 | | -from pyrogram.api import types |
30 | | -from pyrogram.client.storage.storage import Storage |
| 24 | +from .sqlite_storage import SQLiteStorage |
31 | 25 |
|
32 | 26 | log = logging.getLogger(__name__) |
33 | 27 |
|
34 | 28 |
|
35 | | -class MemoryStorage(Storage): |
36 | | - SCHEMA_VERSION = 1 |
37 | | - USERNAME_TTL = 8 * 60 * 60 |
38 | | - SESSION_STRING_FMT = ">B?256sI?" |
39 | | - SESSION_STRING_SIZE = 351 |
40 | | - |
| 29 | +class MemoryStorage(SQLiteStorage): |
41 | 30 | def __init__(self, name: str): |
42 | 31 | super().__init__(name) |
43 | 32 |
|
44 | | - self.conn = None # type: sqlite3.Connection |
45 | | - self.lock = Lock() |
46 | | - |
47 | | - def create(self): |
48 | | - with self.lock, self.conn: |
49 | | - with open(str(Path(__file__).parent / "schema.sql"), "r") as schema: |
50 | | - self.conn.executescript(schema.read()) |
51 | | - |
52 | | - self.conn.execute( |
53 | | - "INSERT INTO version VALUES (?)", |
54 | | - (self.SCHEMA_VERSION,) |
55 | | - ) |
56 | | - |
57 | | - self.conn.execute( |
58 | | - "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)", |
59 | | - (1, None, None, 0, None, None) |
60 | | - ) |
61 | | - |
62 | | - def _import_session_string(self, session_string: str): |
63 | | - decoded = base64.urlsafe_b64decode(session_string + "=" * (-len(session_string) % 4)) |
64 | | - return struct.unpack(self.SESSION_STRING_FMT, decoded) |
65 | | - |
66 | | - def export_session_string(self): |
67 | | - packed = struct.pack( |
68 | | - self.SESSION_STRING_FMT, |
69 | | - self.dc_id, |
70 | | - self.test_mode, |
71 | | - self.auth_key, |
72 | | - self.user_id, |
73 | | - self.is_bot |
74 | | - ) |
75 | | - |
76 | | - return base64.urlsafe_b64encode(packed).decode().rstrip("=") |
77 | | - |
78 | | - # noinspection PyAttributeOutsideInit |
79 | 33 | def open(self): |
80 | 34 | self.conn = sqlite3.connect(":memory:", check_same_thread=False) |
81 | 35 | self.create() |
82 | 36 |
|
83 | 37 | if self.name != ":memory:": |
84 | | - imported_session_string = self._import_session_string(self.name) |
85 | | - |
86 | | - self.dc_id, self.test_mode, self.auth_key, self.user_id, self.is_bot = imported_session_string |
87 | | - self.date = 0 |
88 | | - |
89 | | - # noinspection PyAttributeOutsideInit |
90 | | - def save(self): |
91 | | - self.date = int(time.time()) |
92 | | - |
93 | | - with self.lock: |
94 | | - self.conn.commit() |
95 | | - |
96 | | - def close(self): |
97 | | - with self.lock: |
98 | | - self.conn.close() |
99 | | - |
100 | | - def destroy(self): |
101 | | - pass |
102 | | - |
103 | | - def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): |
104 | | - with self.lock: |
105 | | - self.conn.executemany( |
106 | | - "REPLACE INTO peers (id, access_hash, type, username, phone_number)" |
107 | | - "VALUES (?, ?, ?, ?, ?)", |
108 | | - peers |
109 | | - ) |
110 | | - |
111 | | - def clear_peers(self): |
112 | | - with self.lock, self.conn: |
113 | | - self.conn.execute( |
114 | | - "DELETE FROM peers" |
115 | | - ) |
116 | | - |
117 | | - @staticmethod |
118 | | - def _get_input_peer(peer_id: int, access_hash: int, peer_type: str): |
119 | | - if peer_type in ["user", "bot"]: |
120 | | - return types.InputPeerUser( |
121 | | - user_id=peer_id, |
122 | | - access_hash=access_hash |
| 38 | + dc_id, test_mode, auth_key, user_id, is_bot = struct.unpack( |
| 39 | + self.SESSION_STRING_FORMAT, |
| 40 | + base64.urlsafe_b64decode( |
| 41 | + self.name + "=" * (-len(self.name) % 4) |
| 42 | + ) |
123 | 43 | ) |
124 | 44 |
|
125 | | - if peer_type == "group": |
126 | | - return types.InputPeerChat( |
127 | | - chat_id=-peer_id |
128 | | - ) |
129 | | - |
130 | | - if peer_type in ["channel", "supergroup"]: |
131 | | - return types.InputPeerChannel( |
132 | | - channel_id=int(str(peer_id)[4:]), |
133 | | - access_hash=access_hash |
134 | | - ) |
135 | | - |
136 | | - raise ValueError("Invalid peer type: {}".format(peer_type)) |
137 | | - |
138 | | - def get_peer_by_id(self, peer_id: int): |
139 | | - r = self.conn.execute( |
140 | | - "SELECT id, access_hash, type FROM peers WHERE id = ?", |
141 | | - (peer_id,) |
142 | | - ).fetchone() |
143 | | - |
144 | | - if r is None: |
145 | | - raise KeyError("ID not found: {}".format(peer_id)) |
| 45 | + self.dc_id(dc_id) |
| 46 | + self.test_mode(test_mode) |
| 47 | + self.auth_key(auth_key) |
| 48 | + self.user_id(user_id) |
| 49 | + self.is_bot(is_bot) |
| 50 | + self.date(0) |
146 | 51 |
|
147 | | - return self._get_input_peer(*r) |
148 | | - |
149 | | - def get_peer_by_username(self, username: str): |
150 | | - r = self.conn.execute( |
151 | | - "SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?", |
152 | | - (username,) |
153 | | - ).fetchone() |
154 | | - |
155 | | - if r is None: |
156 | | - raise KeyError("Username not found: {}".format(username)) |
157 | | - |
158 | | - if abs(time.time() - r[3]) > self.USERNAME_TTL: |
159 | | - raise KeyError("Username expired: {}".format(username)) |
160 | | - |
161 | | - return self._get_input_peer(*r[:3]) |
162 | | - |
163 | | - def get_peer_by_phone_number(self, phone_number: str): |
164 | | - r = self.conn.execute( |
165 | | - "SELECT id, access_hash, type FROM peers WHERE phone_number = ?", |
166 | | - (phone_number,) |
167 | | - ).fetchone() |
168 | | - |
169 | | - if r is None: |
170 | | - raise KeyError("Phone number not found: {}".format(phone_number)) |
171 | | - |
172 | | - return self._get_input_peer(*r) |
173 | | - |
174 | | - @property |
175 | | - def peers_count(self): |
176 | | - return self.conn.execute( |
177 | | - "SELECT COUNT(*) FROM peers" |
178 | | - ).fetchone()[0] |
179 | | - |
180 | | - def _get(self): |
181 | | - attr = inspect.stack()[1].function |
182 | | - |
183 | | - return self.conn.execute( |
184 | | - "SELECT {} FROM sessions".format(attr) |
185 | | - ).fetchone()[0] |
186 | | - |
187 | | - def _set(self, value): |
188 | | - attr = inspect.stack()[1].function |
189 | | - |
190 | | - with self.lock, self.conn: |
191 | | - self.conn.execute( |
192 | | - "UPDATE sessions SET {} = ?".format(attr), |
193 | | - (value,) |
194 | | - ) |
195 | | - |
196 | | - @property |
197 | | - def dc_id(self): |
198 | | - return self._get() |
199 | | - |
200 | | - @dc_id.setter |
201 | | - def dc_id(self, value): |
202 | | - self._set(value) |
203 | | - |
204 | | - @property |
205 | | - def test_mode(self): |
206 | | - return self._get() |
207 | | - |
208 | | - @test_mode.setter |
209 | | - def test_mode(self, value): |
210 | | - self._set(value) |
211 | | - |
212 | | - @property |
213 | | - def auth_key(self): |
214 | | - return self._get() |
215 | | - |
216 | | - @auth_key.setter |
217 | | - def auth_key(self, value): |
218 | | - self._set(value) |
219 | | - |
220 | | - @property |
221 | | - def date(self): |
222 | | - return self._get() |
223 | | - |
224 | | - @date.setter |
225 | | - def date(self, value): |
226 | | - self._set(value) |
227 | | - |
228 | | - @property |
229 | | - def user_id(self): |
230 | | - return self._get() |
231 | | - |
232 | | - @user_id.setter |
233 | | - def user_id(self, value): |
234 | | - self._set(value) |
235 | | - |
236 | | - @property |
237 | | - def is_bot(self): |
238 | | - return self._get() |
239 | | - |
240 | | - @is_bot.setter |
241 | | - def is_bot(self, value): |
242 | | - self._set(value) |
| 52 | + def delete(self): |
| 53 | + pass |
0 commit comments