Skip to content

Commit

Permalink
wsocket: use queue for rechecking accounts
Browse files Browse the repository at this point in the history
  • Loading branch information
afalaleev committed Jan 15, 2025
1 parent dab78f0 commit b99f8c1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
42 changes: 24 additions & 18 deletions common/solana_rpc/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from dataclasses import dataclass
from typing import Union, Sequence, Literal, Generic, TypeVar, Final, Any
from collections import deque

import aiohttp as _ws
import pydantic as _pyd
Expand Down Expand Up @@ -73,12 +74,11 @@ class _SolWsObjInfo(Generic[_SolWsObjKey, _SolWsObj]):
sub_id: int | None
key: _SolWsObjKey | None
obj: _SolWsObj | None
last_nsec: int


class _SolWsSession(Generic[_SolWsObjKey, _SolWsObj]):
_ObjInfo = _SolWsObjInfo[_SolWsObjKey, _SolWsObj]
_empty_info: Final[_ObjInfo] = _SolWsObjInfo(None, None, None, None, 0)
_empty_info: Final[_ObjInfo] = _SolWsObjInfo(None, None, None, None)

def __init__(self, cfg: Config, sol_client: SolClient, *, ws_endpoint: HttpStrOrURL | None = None) -> None:
self._cfg = cfg
Expand Down Expand Up @@ -189,7 +189,7 @@ async def _wait(self, timeout_sec: float | None) -> None:
assert item.result not in self._sub_dict, f"subscription {item.result} for {key} already exists?"

self._sub_dict[item.result] = key
info = _SolWsObjInfo(key=key, obj=info.obj, req_id=item.id, sub_id=item.result, last_nsec=now)
info = _SolWsObjInfo(key=key, obj=info.obj, req_id=item.id, sub_id=item.result)
self._obj_dict[key] = info
# _LOG.debug("got subscription %s for %s", item.result, key)
else:
Expand All @@ -208,8 +208,7 @@ async def _sub_obj(self, key: _SolWsObjKey, obj: _SolWsObj, commit: SolCommit) -
return

req_id = self._get_next_id()
now = time.monotonic_ns()
info = _SolWsObjInfo(key=key, obj=obj, req_id=req_id, sub_id=None, last_nsec=now)
info = _SolWsObjInfo(key=key, obj=obj, req_id=req_id, sub_id=None)
self._req_dict[req_id] = key
self._obj_dict[key] = info

Expand Down Expand Up @@ -297,19 +296,25 @@ def _new_unsub_request(self, req_id: int, sub_id: int) -> _SolWsSendData:
return _SoldersUnsubTxSig(sub_id, req_id)


@dataclass(frozen=True)
class _RecheckAcctInfo:
key: SolPubKey
insert_nsec: int


class SolWatchAccountSession(_SolWsSession[SolPubKey, SolAccountModel]):
_AcctInfo = _SolWsObjInfo[SolPubKey, SolAccountModel]

def __init__(self, *args, **kwargs) -> None:
commit = kwargs.pop("commit", SolCommit.Confirmed)
update_nsec = kwargs.pop("force_update_sec", 0) * (10**9)
update_nsec = int(kwargs.pop("force_check_sec", 0) * (10**9))
super().__init__(*args, **kwargs)
self._commit = commit
self._force_update_nsec = update_nsec
self._reconnect_future: asyncio.Future[Any] | None = None
self._update_future: asyncio.Future[Any] | None = None
self._recheck_queue: deque[_RecheckAcctInfo] = deque()
self._chg_key_set: set[SolPubKey] = set()
self._last_nsec: int = 0

async def update(self, *, timeout_sec: float = 0.0001) -> None:
if self._update_future:
Expand Down Expand Up @@ -341,15 +346,13 @@ def get_account(self, addr: SolPubKey) -> SolAccountModel | None:
return info.obj if (info := self._obj_dict.get(addr, None)) else None

def pop_changed_key_list(self) -> Sequence[SolPubKey]:
key_list, self._chg_key_set = tuple(self._chg_key_set), set()
if self._force_update_nsec:
now = time.monotonic_ns()
if key_list:
self._last_nsec = now
elif (check_nsec := now - self._force_update_nsec) < self._last_nsec:
self._last_nsec = now
key_list = tuple([x.key for x in self._obj_dict.values() if x.last_nsec < check_nsec])
return key_list
key_list, self._chg_key_set = list(self._chg_key_set), set()

last_nsec = time.monotonic_ns() - self._force_update_nsec
while self._recheck_queue and (self._recheck_queue[0].insert_nsec <= last_nsec):
key_list.append(self._recheck_queue.popleft().key)

return tuple(key_list)

async def _on_close(self) -> None:
if self._reconnect_future:
Expand All @@ -374,11 +377,14 @@ async def _on_close(self) -> None:

def _on_sub_notif(self, info: _AcctInfo, data: _SoldersAcctNotif, now: int) -> None:
acct = SolAccountModel.from_raw(info.key, data.result.value)
info = _SolWsObjInfo(req_id=info.req_id, sub_id=info.sub_id, key=info.key, obj=acct, last_nsec=now)
info = _SolWsObjInfo(req_id=info.req_id, sub_id=info.sub_id, key=info.key, obj=acct)
self._obj_dict[info.key] = info
self._chg_key_set.add(info.key)
self._sub_dict[info.sub_id] = info.key

if info.key not in self._chg_key_set:
self._chg_key_set.add(info.key)
self._recheck_queue.append(_RecheckAcctInfo(info.key, now))

def _new_sub_request(self, info: _AcctInfo, commit: SolCommit) -> _SolWsSendData:
cfg = _SoldersAcctCfg(encoding=_SoldersAcctEnc.Base64, commitment=commit.to_rpc_commit())
return _SoldersSubAcct(info.key, cfg, info.req_id)
Expand Down
5 changes: 2 additions & 3 deletions proxy/executor/skd_tree_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import time
from typing import Generator, Final

from common.config.constants import ONE_BLOCK_SEC
Expand All @@ -17,7 +16,7 @@


class NeonSkdTreeParser(ExecutorComponent):
_sleep_sec: Final[float] = ONE_BLOCK_SEC * 3
_recheck_sec: Final[float] = ONE_BLOCK_SEC * 3

def __init__(self, server: ExecutorServerAbc, payer: NeonAddress, nonce: int) -> None:
super().__init__(server)
Expand All @@ -27,7 +26,7 @@ def __init__(self, server: ExecutorServerAbc, payer: NeonAddress, nonce: int) ->
self._tree: NeonSkdTreeModel | None = None
self._neon_tx_hash = EthTxHash.default()

self._watch_session = SolWatchAccountSession(self._cfg, self._sol_client, force_update_sec=3)
self._watch_session = SolWatchAccountSession(self._cfg, self._sol_client, force_check_sec=self._recheck_sec)

async def start(self) -> None:
await self._watch_session.connect()
Expand Down
4 changes: 2 additions & 2 deletions proxy/mempool/transaction_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Final, Sequence

from common.config.config import Config
from common.config.constants import ONE_BLOCK_SEC
from common.config.constants import ONE_BLOCK_SEC, MIN_FINALIZE_SEC
from common.ethereum.hash import EthAddress
from common.neon.address import NeonAddress
from common.neon.transaction_model import NeonTxModel
Expand Down Expand Up @@ -406,7 +406,7 @@ def __init__(
global_tx_dict: MpTxDict,
) -> None:
self._core_api_client = core_api_client
self._watch_session = SolWatchAccountSession(cfg, sol_client, force_update_sec=ONE_BLOCK_SEC * 40)
self._watch_session = SolWatchAccountSession(cfg, sol_client, force_check_sec=MIN_FINALIZE_SEC)
self._capacity: Final[int] = cfg.mp_capacity
self._capacity_high_watermark: Final[int] = int(self._capacity * cfg.mp_capacity_high_watermark)
self._eviction_timeout_sec = cfg.mp_eviction_timeout_sec
Expand Down

0 comments on commit b99f8c1

Please sign in to comment.