Skip to content

Commit 0097df2

Browse files
committed
Rework File and Memory storage to accommodate the new abstract class
1 parent 1efce33 commit 0097df2

File tree

2 files changed

+24
-222
lines changed

2 files changed

+24
-222
lines changed

pyrogram/client/storage/file_storage.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,29 @@
2222
import os
2323
import sqlite3
2424
from pathlib import Path
25-
from threading import Lock
2625

27-
from .memory_storage import MemoryStorage
26+
from .sqlite_storage import SQLiteStorage
2827

2928
log = logging.getLogger(__name__)
3029

3130

32-
class FileStorage(MemoryStorage):
31+
class FileStorage(SQLiteStorage):
3332
FILE_EXTENSION = ".session"
3433

3534
def __init__(self, name: str, workdir: Path):
3635
super().__init__(name)
3736

38-
self.workdir = workdir
3937
self.database = workdir / (self.name + self.FILE_EXTENSION)
40-
self.conn = None # type: sqlite3.Connection
41-
self.lock = Lock()
4238

43-
# noinspection PyAttributeOutsideInit
4439
def migrate_from_json(self, session_json: dict):
4540
self.open()
4641

47-
self.dc_id = session_json["dc_id"]
48-
self.test_mode = session_json["test_mode"]
49-
self.auth_key = base64.b64decode("".join(session_json["auth_key"]))
50-
self.user_id = session_json["user_id"]
51-
self.date = session_json.get("date", 0)
52-
self.is_bot = session_json.get("is_bot", False)
42+
self.dc_id(session_json["dc_id"])
43+
self.test_mode(session_json["test_mode"])
44+
self.auth_key(base64.b64decode("".join(session_json["auth_key"])))
45+
self.user_id(session_json["user_id"])
46+
self.date(session_json.get("date", 0))
47+
self.is_bot(session_json.get("is_bot", False))
5348

5449
peers_by_id = session_json.get("peers_by_id", {})
5550
peers_by_phone = session_json.get("peers_by_phone", {})
@@ -98,11 +93,7 @@ def open(self):
9893
if Path(path.name + ".OLD").is_file():
9994
log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name))
10095

101-
self.conn = sqlite3.connect(
102-
str(path),
103-
timeout=1,
104-
check_same_thread=False
105-
)
96+
self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
10697

10798
if not file_exists:
10899
self.create()

pyrogram/client/storage/memory_storage.py

Lines changed: 15 additions & 204 deletions
Original file line numberDiff line numberDiff line change
@@ -17,226 +17,37 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import base64
20-
import inspect
2120
import logging
2221
import sqlite3
2322
import struct
24-
import time
25-
from pathlib import Path
26-
from threading import Lock
27-
from typing import List, Tuple
2823

29-
from pyrogram.api import types
30-
from pyrogram.client.storage.storage import Storage
24+
from .sqlite_storage import SQLiteStorage
3125

3226
log = logging.getLogger(__name__)
3327

3428

35-
class MemoryStorage(Storage):
36-
SCHEMA_VERSION = 1
37-
USERNAME_TTL = 8 * 60 * 60
38-
SESSION_STRING_FMT = ">B?256sI?"
39-
SESSION_STRING_SIZE = 351
40-
29+
class MemoryStorage(SQLiteStorage):
4130
def __init__(self, name: str):
4231
super().__init__(name)
4332

44-
self.conn = None # type: sqlite3.Connection
45-
self.lock = Lock()
46-
47-
def create(self):
48-
with self.lock, self.conn:
49-
with open(str(Path(__file__).parent / "schema.sql"), "r") as schema:
50-
self.conn.executescript(schema.read())
51-
52-
self.conn.execute(
53-
"INSERT INTO version VALUES (?)",
54-
(self.SCHEMA_VERSION,)
55-
)
56-
57-
self.conn.execute(
58-
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)",
59-
(1, None, None, 0, None, None)
60-
)
61-
62-
def _import_session_string(self, session_string: str):
63-
decoded = base64.urlsafe_b64decode(session_string + "=" * (-len(session_string) % 4))
64-
return struct.unpack(self.SESSION_STRING_FMT, decoded)
65-
66-
def export_session_string(self):
67-
packed = struct.pack(
68-
self.SESSION_STRING_FMT,
69-
self.dc_id,
70-
self.test_mode,
71-
self.auth_key,
72-
self.user_id,
73-
self.is_bot
74-
)
75-
76-
return base64.urlsafe_b64encode(packed).decode().rstrip("=")
77-
78-
# noinspection PyAttributeOutsideInit
7933
def open(self):
8034
self.conn = sqlite3.connect(":memory:", check_same_thread=False)
8135
self.create()
8236

8337
if self.name != ":memory:":
84-
imported_session_string = self._import_session_string(self.name)
85-
86-
self.dc_id, self.test_mode, self.auth_key, self.user_id, self.is_bot = imported_session_string
87-
self.date = 0
88-
89-
# noinspection PyAttributeOutsideInit
90-
def save(self):
91-
self.date = int(time.time())
92-
93-
with self.lock:
94-
self.conn.commit()
95-
96-
def close(self):
97-
with self.lock:
98-
self.conn.close()
99-
100-
def destroy(self):
101-
pass
102-
103-
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
104-
with self.lock:
105-
self.conn.executemany(
106-
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
107-
"VALUES (?, ?, ?, ?, ?)",
108-
peers
109-
)
110-
111-
def clear_peers(self):
112-
with self.lock, self.conn:
113-
self.conn.execute(
114-
"DELETE FROM peers"
115-
)
116-
117-
@staticmethod
118-
def _get_input_peer(peer_id: int, access_hash: int, peer_type: str):
119-
if peer_type in ["user", "bot"]:
120-
return types.InputPeerUser(
121-
user_id=peer_id,
122-
access_hash=access_hash
38+
dc_id, test_mode, auth_key, user_id, is_bot = struct.unpack(
39+
self.SESSION_STRING_FORMAT,
40+
base64.urlsafe_b64decode(
41+
self.name + "=" * (-len(self.name) % 4)
42+
)
12343
)
12444

125-
if peer_type == "group":
126-
return types.InputPeerChat(
127-
chat_id=-peer_id
128-
)
129-
130-
if peer_type in ["channel", "supergroup"]:
131-
return types.InputPeerChannel(
132-
channel_id=int(str(peer_id)[4:]),
133-
access_hash=access_hash
134-
)
135-
136-
raise ValueError("Invalid peer type: {}".format(peer_type))
137-
138-
def get_peer_by_id(self, peer_id: int):
139-
r = self.conn.execute(
140-
"SELECT id, access_hash, type FROM peers WHERE id = ?",
141-
(peer_id,)
142-
).fetchone()
143-
144-
if r is None:
145-
raise KeyError("ID not found: {}".format(peer_id))
45+
self.dc_id(dc_id)
46+
self.test_mode(test_mode)
47+
self.auth_key(auth_key)
48+
self.user_id(user_id)
49+
self.is_bot(is_bot)
50+
self.date(0)
14651

147-
return self._get_input_peer(*r)
148-
149-
def get_peer_by_username(self, username: str):
150-
r = self.conn.execute(
151-
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?",
152-
(username,)
153-
).fetchone()
154-
155-
if r is None:
156-
raise KeyError("Username not found: {}".format(username))
157-
158-
if abs(time.time() - r[3]) > self.USERNAME_TTL:
159-
raise KeyError("Username expired: {}".format(username))
160-
161-
return self._get_input_peer(*r[:3])
162-
163-
def get_peer_by_phone_number(self, phone_number: str):
164-
r = self.conn.execute(
165-
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
166-
(phone_number,)
167-
).fetchone()
168-
169-
if r is None:
170-
raise KeyError("Phone number not found: {}".format(phone_number))
171-
172-
return self._get_input_peer(*r)
173-
174-
@property
175-
def peers_count(self):
176-
return self.conn.execute(
177-
"SELECT COUNT(*) FROM peers"
178-
).fetchone()[0]
179-
180-
def _get(self):
181-
attr = inspect.stack()[1].function
182-
183-
return self.conn.execute(
184-
"SELECT {} FROM sessions".format(attr)
185-
).fetchone()[0]
186-
187-
def _set(self, value):
188-
attr = inspect.stack()[1].function
189-
190-
with self.lock, self.conn:
191-
self.conn.execute(
192-
"UPDATE sessions SET {} = ?".format(attr),
193-
(value,)
194-
)
195-
196-
@property
197-
def dc_id(self):
198-
return self._get()
199-
200-
@dc_id.setter
201-
def dc_id(self, value):
202-
self._set(value)
203-
204-
@property
205-
def test_mode(self):
206-
return self._get()
207-
208-
@test_mode.setter
209-
def test_mode(self, value):
210-
self._set(value)
211-
212-
@property
213-
def auth_key(self):
214-
return self._get()
215-
216-
@auth_key.setter
217-
def auth_key(self, value):
218-
self._set(value)
219-
220-
@property
221-
def date(self):
222-
return self._get()
223-
224-
@date.setter
225-
def date(self, value):
226-
self._set(value)
227-
228-
@property
229-
def user_id(self):
230-
return self._get()
231-
232-
@user_id.setter
233-
def user_id(self, value):
234-
self._set(value)
235-
236-
@property
237-
def is_bot(self):
238-
return self._get()
239-
240-
@is_bot.setter
241-
def is_bot(self, value):
242-
self._set(value)
52+
def delete(self):
53+
pass

0 commit comments

Comments
 (0)