Skip to content

Commit 5dc33c6

Browse files
committed
add in-memory session storage, refactor session storages, remove mixin
1 parent 9c4e9e1 commit 5dc33c6

File tree

11 files changed

+267
-188
lines changed

11 files changed

+267
-188
lines changed

pyrogram/client/client.py

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@
5050
from pyrogram.client.handlers import DisconnectHandler
5151
from pyrogram.client.handlers.handler import Handler
5252
from pyrogram.client.methods.password.utils import compute_check
53-
from pyrogram.client.session_storage import BaseSessionConfig
5453
from pyrogram.crypto import AES
5554
from pyrogram.session import Auth, Session
5655
from .dispatcher import Dispatcher
5756
from .ext import utils, Syncer, BaseClient
5857
from .methods import Methods
59-
from .session_storage import SessionDoesNotExist
60-
from .session_storage.json_session_storage import JsonSessionStorage
61-
from .session_storage.string_session_storage import StringSessionStorage
58+
from .session_storage import (
59+
SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage,
60+
StringSessionStorage
61+
)
6262

6363
log = logging.getLogger(__name__)
6464

@@ -183,7 +183,7 @@ class Client(Methods, BaseClient):
183183
"""
184184

185185
def __init__(self,
186-
session_name: Union[str, BaseSessionConfig],
186+
session_name: Union[str, SessionStorage],
187187
api_id: Union[int, str] = None,
188188
api_hash: str = None,
189189
app_version: str = None,
@@ -209,14 +209,16 @@ def __init__(self,
209209
takeout: bool = None):
210210

211211
if isinstance(session_name, str):
212-
if session_name.startswith(':'):
212+
if session_name == ':memory:':
213+
session_storage = MemorySessionStorage(self)
214+
elif session_name.startswith(':'):
213215
session_storage = StringSessionStorage(self, session_name)
214216
else:
215217
session_storage = JsonSessionStorage(self, session_name)
216-
elif isinstance(session_name, BaseSessionConfig):
217-
session_storage = session_name.session_storage_cls(self, session_name)
218+
elif isinstance(session_name, SessionStorage):
219+
session_storage = session_name
218220
else:
219-
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
221+
raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass')
220222

221223
super().__init__(session_storage)
222224

@@ -230,7 +232,7 @@ def __init__(self,
230232
self.ipv6 = ipv6
231233
# TODO: Make code consistent, use underscore for private/protected fields
232234
self._proxy = proxy
233-
self.test_mode = test_mode
235+
self.session_storage.test_mode = test_mode
234236
self.phone_number = phone_number
235237
self.phone_code = phone_code
236238
self.password = password
@@ -282,10 +284,10 @@ def start(self):
282284
raise ConnectionError("Client has already been started")
283285

284286
if isinstance(self.session_storage, JsonSessionStorage):
285-
if self.BOT_TOKEN_RE.match(self.session_storage.session_data):
286-
self.is_bot = True
287-
self.bot_token = self.session_storage.session_data
288-
self.session_storage.session_data = self.session_storage.session_data.split(":")[0]
287+
if self.BOT_TOKEN_RE.match(self.session_storage._session_name):
288+
self.session_storage.is_bot = True
289+
self.bot_token = self.session_storage._session_name
290+
self.session_storage._session_name = self.session_storage._session_name.split(":")[0]
289291
warnings.warn('\nYou are using a bot token as session name.\n'
290292
'It will be deprecated in next update, please use session file name to load '
291293
'existing sessions and bot_token argument to create new sessions.',
@@ -297,33 +299,33 @@ def start(self):
297299

298300
self.session = Session(
299301
self,
300-
self.dc_id,
301-
self.auth_key
302+
self.session_storage.dc_id,
303+
self.session_storage.auth_key
302304
)
303305

304306
self.session.start()
305307
self.is_started = True
306308

307309
try:
308-
if self.user_id is None:
310+
if self.session_storage.user_id is None:
309311
if self.bot_token is None:
310312
self.authorize_user()
311313
else:
312-
self.is_bot = True
314+
self.session_storage.is_bot = True
313315
self.authorize_bot()
314316

315317
self.save_session()
316318

317-
if not self.is_bot:
319+
if not self.session_storage.is_bot:
318320
if self.takeout:
319321
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
320322
log.warning("Takeout session {} initiated".format(self.takeout_id))
321323

322324
now = time.time()
323325

324-
if abs(now - self.date) > Client.OFFLINE_SLEEP:
325-
self.peers_by_username.clear()
326-
self.peers_by_phone.clear()
326+
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
327+
self.session_storage.peers_by_username.clear()
328+
self.session_storage.peers_by_phone.clear()
327329

328330
self.get_initial_dialogs()
329331
self.get_contacts()
@@ -512,19 +514,20 @@ def authorize_bot(self):
512514
except UserMigrate as e:
513515
self.session.stop()
514516

515-
self.dc_id = e.x
516-
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
517+
self.session_storage.dc_id = e.x
518+
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
519+
self.ipv6, self._proxy).create()
517520

518521
self.session = Session(
519522
self,
520-
self.dc_id,
521-
self.auth_key
523+
self.session_storage.dc_id,
524+
self.session_storage.auth_key
522525
)
523526

524527
self.session.start()
525528
self.authorize_bot()
526529
else:
527-
self.user_id = r.user.id
530+
self.session_storage.user_id = r.user.id
528531

529532
print("Logged in successfully as @{}".format(r.user.username))
530533

@@ -564,19 +567,19 @@ def default_phone_number_callback():
564567
except (PhoneMigrate, NetworkMigrate) as e:
565568
self.session.stop()
566569

567-
self.dc_id = e.x
570+
self.session_storage.dc_id = e.x
568571

569-
self.auth_key = Auth(
570-
self.dc_id,
571-
self.test_mode,
572+
self.session_storage.auth_key = Auth(
573+
self.session_storage.dc_id,
574+
self.session_storage.test_mode,
572575
self.ipv6,
573576
self._proxy
574577
).create()
575578

576579
self.session = Session(
577580
self,
578-
self.dc_id,
579-
self.auth_key
581+
self.session_storage.dc_id,
582+
self.session_storage.auth_key
580583
)
581584

582585
self.session.start()
@@ -752,7 +755,7 @@ def default_recovery_callback(email_pattern: str) -> str:
752755
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
753756

754757
self.password = None
755-
self.user_id = r.user.id
758+
self.session_storage.user_id = r.user.id
756759

757760
print("Logged in successfully as {}".format(r.user.first_name))
758761

@@ -776,13 +779,13 @@ def fetch_peers(self, entities: List[Union[types.User,
776779
access_hash=access_hash
777780
)
778781

779-
self.peers_by_id[user_id] = input_peer
782+
self.session_storage.peers_by_id[user_id] = input_peer
780783

781784
if username is not None:
782-
self.peers_by_username[username.lower()] = input_peer
785+
self.session_storage.peers_by_username[username.lower()] = input_peer
783786

784787
if phone is not None:
785-
self.peers_by_phone[phone] = input_peer
788+
self.session_storage.peers_by_phone[phone] = input_peer
786789

787790
if isinstance(entity, (types.Chat, types.ChatForbidden)):
788791
chat_id = entity.id
@@ -792,7 +795,7 @@ def fetch_peers(self, entities: List[Union[types.User,
792795
chat_id=chat_id
793796
)
794797

795-
self.peers_by_id[peer_id] = input_peer
798+
self.session_storage.peers_by_id[peer_id] = input_peer
796799

797800
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
798801
channel_id = entity.id
@@ -810,10 +813,10 @@ def fetch_peers(self, entities: List[Union[types.User,
810813
access_hash=access_hash
811814
)
812815

813-
self.peers_by_id[peer_id] = input_peer
816+
self.session_storage.peers_by_id[peer_id] = input_peer
814817

815818
if username is not None:
816-
self.peers_by_username[username.lower()] = input_peer
819+
self.session_storage.peers_by_username[username.lower()] = input_peer
817820

818821
def download_worker(self):
819822
name = threading.current_thread().name
@@ -1127,10 +1130,11 @@ def load_config(self):
11271130

11281131
def load_session(self):
11291132
try:
1130-
self.session_storage.load_session()
1133+
self.session_storage.load()
11311134
except SessionDoesNotExist:
11321135
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
1133-
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
1136+
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
1137+
self.ipv6, self._proxy).create()
11341138

11351139
def load_plugins(self):
11361140
if self.plugins.get("enabled", False):
@@ -1237,7 +1241,7 @@ def load_plugins(self):
12371241
log.warning('No plugin loaded from "{}"'.format(root))
12381242

12391243
def save_session(self):
1240-
self.session_storage.save_session()
1244+
self.session_storage.save()
12411245

12421246
def get_initial_dialogs_chunk(self,
12431247
offset_date: int = 0):
@@ -1257,7 +1261,7 @@ def get_initial_dialogs_chunk(self,
12571261
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
12581262
time.sleep(e.x)
12591263
else:
1260-
log.info("Total peers: {}".format(len(self.peers_by_id)))
1264+
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
12611265
return r
12621266

12631267
def get_initial_dialogs(self):
@@ -1293,7 +1297,7 @@ def resolve_peer(self,
12931297
``KeyError`` in case the peer doesn't exist in the internal database.
12941298
"""
12951299
try:
1296-
return self.peers_by_id[peer_id]
1300+
return self.session_storage.peers_by_id[peer_id]
12971301
except KeyError:
12981302
if type(peer_id) is str:
12991303
if peer_id in ("self", "me"):
@@ -1304,17 +1308,17 @@ def resolve_peer(self,
13041308
try:
13051309
int(peer_id)
13061310
except ValueError:
1307-
if peer_id not in self.peers_by_username:
1311+
if peer_id not in self.session_storage.peers_by_username:
13081312
self.send(
13091313
functions.contacts.ResolveUsername(
13101314
username=peer_id
13111315
)
13121316
)
13131317

1314-
return self.peers_by_username[peer_id]
1318+
return self.session_storage.peers_by_username[peer_id]
13151319
else:
13161320
try:
1317-
return self.peers_by_phone[peer_id]
1321+
return self.session_storage.peers_by_phone[peer_id]
13181322
except KeyError:
13191323
raise PeerIdInvalid
13201324

@@ -1341,7 +1345,7 @@ def resolve_peer(self,
13411345
)
13421346

13431347
try:
1344-
return self.peers_by_id[peer_id]
1348+
return self.session_storage.peers_by_id[peer_id]
13451349
except KeyError:
13461350
raise PeerIdInvalid
13471351

@@ -1411,7 +1415,7 @@ def save_file(self,
14111415
file_id = file_id or self.rnd_id()
14121416
md5_sum = md5() if not is_big and not is_missing_part else None
14131417

1414-
session = Session(self, self.dc_id, self.auth_key, is_media=True)
1418+
session = Session(self, self.session_storage.dc_id, self.session_storage.auth_key, is_media=True)
14151419
session.start()
14161420

14171421
try:
@@ -1492,7 +1496,7 @@ def get_file(self,
14921496
session = self.media_sessions.get(dc_id, None)
14931497

14941498
if session is None:
1495-
if dc_id != self.dc_id:
1499+
if dc_id != self.session_storage.dc_id:
14961500
exported_auth = self.send(
14971501
functions.auth.ExportAuthorization(
14981502
dc_id=dc_id
@@ -1502,7 +1506,7 @@ def get_file(self,
15021506
session = Session(
15031507
self,
15041508
dc_id,
1505-
Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
1509+
Auth(dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
15061510
is_media=True
15071511
)
15081512

@@ -1520,7 +1524,7 @@ def get_file(self,
15201524
session = Session(
15211525
self,
15221526
dc_id,
1523-
self.auth_key,
1527+
self.session_storage.auth_key,
15241528
is_media=True
15251529
)
15261530

@@ -1588,7 +1592,7 @@ def get_file(self,
15881592
cdn_session = Session(
15891593
self,
15901594
r.dc_id,
1591-
Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
1595+
Auth(r.dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
15921596
is_media=True,
15931597
is_cdn=True
15941598
)

pyrogram/client/ext/base_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
from pyrogram import __version__
2525
from ..style import Markdown, HTML
2626
from ...session.internals import MsgId
27-
from ..session_storage import SessionStorageMixin, BaseSessionStorage
27+
from ..session_storage import SessionStorage
2828

2929

30-
class BaseClient(SessionStorageMixin):
30+
class BaseClient:
3131
class StopTransmission(StopIteration):
3232
pass
3333

@@ -68,14 +68,14 @@ class StopTransmission(StopIteration):
6868
13: "video_note"
6969
}
7070

71-
def __init__(self, session_storage: BaseSessionStorage):
71+
def __init__(self, session_storage: SessionStorage):
7272
self.session_storage = session_storage
7373

7474
self.rnd_id = MsgId
7575
self.channels_pts = {}
7676

77-
self.markdown = Markdown(self.peers_by_id)
78-
self.html = HTML(self.peers_by_id)
77+
self.markdown = Markdown(self.session_storage.peers_by_id)
78+
self.html = HTML(self.session_storage.peers_by_id)
7979

8080
self.session = None
8181
self.media_sessions = {}

pyrogram/client/ext/syncer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ def worker(cls):
8181

8282
@classmethod
8383
def sync(cls, client):
84-
client.date = int(time.time())
84+
client.session_storage.date = int(time.time())
8585
try:
86-
client.session_storage.save_session(sync=True)
86+
client.session_storage.save(sync=True)
8787
except Exception as e:
8888
log.critical(e, exc_info=True)
8989
else:

pyrogram/client/methods/contacts/get_contacts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ def get_contacts(self):
4444
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
4545
time.sleep(e.x)
4646
else:
47-
log.info("Total contacts: {}".format(len(self.peers_by_phone)))
47+
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
4848
return [pyrogram.User._parse(self, user) for user in contacts.users]

pyrogram/client/session_storage/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,7 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
from .session_storage_mixin import SessionStorageMixin
20-
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
19+
from .abstract import SessionStorage, SessionDoesNotExist
20+
from .memory import MemorySessionStorage
21+
from .json import JsonSessionStorage
22+
from .string import StringSessionStorage

0 commit comments

Comments
 (0)