From 72183b294fc45e8a46612ec9c1c4097b6334766e Mon Sep 17 00:00:00 2001 From: wulan17 Date: Mon, 23 Oct 2023 19:57:07 +0700 Subject: [PATCH] Pyrofork: Storage: MongoStorage: Save fragment username(s) to sessions database Signed-off-by: wulan17 --- pyrogram/client.py | 8 ++++++++ pyrogram/storage/mongo_storage.py | 28 +++++++++++++++++++++++++++- pyrogram/storage/storage.py | 3 +++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index e59f20595..f694b4f7d 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -508,6 +508,7 @@ def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]): async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, raw.types.Channel]]) -> bool: is_min = False parsed_peers = [] + usernames = [] for peer in peers: if getattr(peer, "min", False): @@ -525,6 +526,9 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra else peer.usernames[0].username.lower() if peer.usernames else None ) + if peer.usernames is not None and len(peer.usernames) > 1: + for uname in peer.usernames: + usernames.append((peer.id, uname.username.lower())) phone_number = peer.phone peer_type = "bot" if peer.bot else "user" elif isinstance(peer, (raw.types.Chat, raw.types.ChatForbidden)): @@ -539,6 +543,9 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra else peer.usernames[0].username.lower() if peer.usernames else None ) + if peer.usernames is not None and len(peer.usernames) > 1: + for uname in peer.usernames: + usernames.append((peer.id, uname.username.lower())) peer_type = "channel" if peer.broadcast else "supergroup" elif isinstance(peer, raw.types.ChannelForbidden): peer_id = utils.get_channel_id(peer.id) @@ -550,6 +557,7 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra parsed_peers.append((peer_id, access_hash, peer_type, username, phone_number)) await self.storage.update_peers(parsed_peers) + await self.storage.update_usernames(usernames) return is_min diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index 5240917e5..2ffe3f96c 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -56,6 +56,7 @@ def __init__( self.database = database self._peer = database['peers'] self._session = database['session'] + self._usernames = database['usernames'] self._remove_peers = remove_peers async def open(self): @@ -121,6 +122,24 @@ async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): bulk ) + async def update_usernames(self, usernames: List[Tuple[int, str]]): + s = int(time.time()) + bulk = [ + UpdateOne( + {'_id': i[1]}, + {'$set': { + 'peer_id': i[0], + 'last_update_on': s + }}, + upsert=True + ) for i in usernames + ] + if not bulk: + return + await self._usernames.bulk_write( + bulk + ) + async def get_peer_by_id(self, peer_id: int): # id, access_hash, type r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) @@ -134,7 +153,14 @@ async def get_peer_by_username(self, username: str): {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}) if r is None: - raise KeyError(f"Username not found: {username}") + r2 = await self._usernames.find_one({'_id': username}, + {'peer_id': 1, 'last_update_on': 1}) + if r2 is None: + raise KeyError(f"Username not found: {username}") + r = await self._peer.find_one({'_id': r2['peer_id']}, + {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}) + if r is None: + raise KeyError(f"Username not found: {username}") if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL: raise KeyError(f"Username expired: {username}") diff --git a/pyrogram/storage/storage.py b/pyrogram/storage/storage.py index 0689b6826..6ec81723e 100644 --- a/pyrogram/storage/storage.py +++ b/pyrogram/storage/storage.py @@ -47,6 +47,9 @@ async def delete(self): async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): raise NotImplementedError + async def update_usernames(self, usernames: List[Tuple[int, str]]): + raise NotImplementedError + async def get_peer_by_id(self, peer_id: int): raise NotImplementedError