Skip to content

Commit

Permalink
Fix #688: PredictoorContract deepcopy() causes RecursionError (#689)
Browse files Browse the repository at this point in the history
* Fix #688: [Bug, pdr bot] PredictoorContract deepcopy() causes RecursionError

* revert contract: remove deepcopy test

* Fix via two Web3Config objects

* black

* Fixes

* Fix allowance tracking

* More checks

* Add set_token method to PredictoorContract

* Call set_token with updated pp

* Add mock owner

* Type annotation

* Add missing function to mock

* Formatting

* Add and use copy_with_pk function

* Fix type

* Fix

* Add new pk

* Fix mock

* Formatting

* Mypy ignore

* Formatting

* Remove submodule

---------

Co-authored-by: trizin <[email protected]>
  • Loading branch information
trentmc and trizin authored Feb 28, 2024
1 parent 50119b7 commit 50fa66c
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 42 deletions.
18 changes: 11 additions & 7 deletions pdr_backend/contract/predictoor_contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Tuple
from typing import Dict, List, Tuple
from unittest.mock import Mock

from enforce_typing import enforce_types
Expand All @@ -18,9 +18,12 @@
class PredictoorContract(BaseContract): # pylint: disable=too-many-public-methods
def __init__(self, web3_pp, address: str):
super().__init__(web3_pp, address, "ERC20Template3")
self.set_token(web3_pp)
self.last_allowance: Dict[str, int] = {}

def set_token(self, web3_pp):
stake_token = self.get_stake_token()
self.token = Token(web3_pp, stake_token)
self.last_allowance = 0

def is_valid_subscription(self):
"""Does this account have a subscription to this feed yet?"""
Expand Down Expand Up @@ -269,14 +272,15 @@ def submit_prediction(
stake_amt_wei = to_wei(stake_amt)

# Check allowance first, only approve if needed
if self.last_allowance <= 0:
self.last_allowance = self.token.allowance(
allowance = self.last_allowance.get(self.config.owner, 0)
if allowance <= 0:
self.last_allowance[self.config.owner] = self.token.allowance(
self.config.owner, self.contract_address
)
if self.last_allowance < stake_amt_wei:
if allowance < stake_amt_wei:
try:
self.token.approve(self.contract_address, MAX_UINT)
self.last_allowance = MAX_UINT
self.last_allowance[self.config.owner] = MAX_UINT
except Exception as e:
logger.error(
"Error while approving the contract to spend tokens: %s", e
Expand All @@ -296,7 +300,7 @@ def submit_prediction(
predicted_value, stake_amt_wei, prediction_ts
).transact(call_params)
txhash = tx.hex()
self.last_allowance -= stake_amt_wei
self.last_allowance[self.config.owner] -= stake_amt_wei
logger.info("Submitted prediction, txhash: %s", txhash)

if not wait_for_receipt:
Expand Down
8 changes: 8 additions & 0 deletions pdr_backend/ppss/web3_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import random
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
Expand Down Expand Up @@ -331,6 +332,9 @@ def get_current_epoch(self) -> int:
"""Returns an epoch number"""
return self.get_current_epoch_ts() // self.s_per_epoch

def set_token(self, web3_pp):
pass

def get_current_epoch_ts(self) -> UnixTimeS:
"""Returns a timestamp"""
return UnixTimeS(self._w3.eth.timestamp // self.s_per_epoch * self.s_per_epoch)
Expand Down Expand Up @@ -392,5 +396,9 @@ def advance_func(*args, **kwargs): # pylint: disable=unused-argument

assert hasattr(web3_pp.web3_config, "w3")
web3_pp.web3_config.w3 = mock_w3
copy_config = deepcopy(web3_pp.web3_config)
copy_config.owner = "0x3"
web3_pp.web3_config.copy_with_pk = Mock() # type: ignore
web3_pp.web3_config.copy_with_pk.return_value = copy_config

return _mock_pdr_contract
99 changes: 64 additions & 35 deletions pdr_backend/predictoor/predictoor_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ class PredictoorAgent:
- Fetches Predictoor contracts from subgraph, and filters them
- Monitors each contract for epoch changes.
- When a value can be predicted, call calc_stakes()
Prediction is two-sided: it submits for both up and down directions,
with a stake for each.
- But: the contracts have a constraint: an account can only submit
*one* dir'n at an epoch. But we need to submit *both* dir'ns.
- Idea: redo smart contracts. Issue: significant work, especially rollout
- Idea: bot has *two* accounts: one for up, one for down. Yes this works**
OK. Assume two private keys are available. How should bot manage this?
- Idea: implement with a copy of the contract? (one for up, one for down)
- Via copy()? Issue: fails because it's too shallow, misses stuff
- Via deepcopy()? Issue: causes infinite recursion (py bug)
- Via deepcopy() with surgical changes? Issue: error prone
- Via query subgraph twice? Issue: many seconds slower -> annoying
- Via fill in whole contract again? Issue: tedious & error prone
- Idea: implement with a second Web3Config, and JIT switch on tx calls
- **Via 2nd constructor call? Yes, this works.** Easy because few params.
Summary of how to do two-sided predictions:
- two envvars --> two private keys -> two Web3Configs, JIT switch for txs
"""

@enforce_types
Expand All @@ -33,6 +53,19 @@ def __init__(self, ppss: PPSS):
self.ppss = ppss
logger.info(self.ppss)

# set web3_config_up/down (details in class docstring)
self.web3_config_up = self.ppss.web3_pp.web3_config

pk2 = os.getenv("PRIVATE_KEY2")
if pk2 is None:
raise ValueError("Need PRIVATE_KEY2 envvar")
if not hasattr(self.web3_config_up, "owner"):
raise ValueError("Need PRIVATE_KEY envvar")
self.web3_config_down = self.web3_config_up.copy_with_pk(pk2)

if self.web3_config_up.owner == self.web3_config_down.owner:
raise ValueError("private keys must differ")

# set self.feed
cand_feeds: Dict[str, SubgraphFeed] = ppss.web3_pp.query_feed_contracts()
print_feeds(cand_feeds, f"cand feeds, owner={ppss.web3_pp.owner_addrs}")
Expand All @@ -44,18 +77,11 @@ def __init__(self, ppss: PPSS):
print_feeds({feed.address: feed}, "filtered feed")
self.feed: SubgraphFeed = feed

# set self.feed_contract, self.feed_contract2
# set self.feed_contract. For both up/down. See submit_prediction_tx
self.feed_contract: PredictoorContract = ppss.web3_pp.get_single_contract(
feed.address
)

pk2: Optional[str] = os.getenv("PRIVATE_KEY2")
assert pk2 is not None, "Need PRIVATE_KEY2 envvar"
rpc_url: str = self.ppss.web3_pp.rpc_url
web3_config2 = Web3Config(rpc_url, pk2)
self.feed_contract2 = copy.deepcopy(self.feed_contract)
self.feed_contract2.web3_pp.set_web3_config(web3_config2)

# ensure ohlcv data cache is up to date
if self.use_ohlcv_data():
_ = self.get_ohlcv_data()
Expand Down Expand Up @@ -176,20 +202,10 @@ def submit_prediction_txs(
target_slot: UnixTimeS, # a timestamp
):
logger.info("Submit 'up' prediction tx to chain...")
tx1 = self.feed_contract.submit_prediction(
True,
stake_up,
target_slot,
wait_for_receipt=True,
)
tx1 = self.submit_1prediction_tx(True, stake_up, target_slot)

logger.info("Submit 'down' prediction tx to chain...")
tx2 = self.feed_contract2.submit_prediction(
False,
stake_down,
target_slot,
wait_for_receipt=True,
)
tx2 = self.submit_1prediction_tx(False, stake_down, target_slot)

# handle errors
if _tx_failed(tx1) or _tx_failed(tx2):
Expand All @@ -198,22 +214,35 @@ def submit_prediction_txs(
logger.warning(s)

logger.info("Re-submit 'up' prediction tx to chain... (stake=0)")
self.feed_contract.submit_prediction(
True,
1e-10,
target_slot,
wait_for_receipt=True,
)

self.submit_1prediction_tx(True, 1e-10, target_slot)
logger.info("Re-submit 'down' prediction tx to chain... (stake=0)")
self.feed_contract2.submit_prediction(
False,
1e-10,
target_slot,
wait_for_receipt=True,
)

return True
self.submit_1prediction_tx(False, 1e-10, target_slot)

@enforce_types
def submit_1prediction_tx(
self,
direction: bool,
stake: float, # in units of Eth
target_slot: UnixTimeS, # a timestamp
):
web3_config = self._updown_web3_config(direction)
self.feed_contract.web3_pp.set_web3_config(web3_config)
self.feed_contract.set_token(self.feed_contract.web3_pp)

tx = self.feed_contract.submit_prediction(
direction,
stake,
target_slot,
wait_for_receipt=True,
)
return tx

def _updown_web3_config(self, direction: bool) -> Web3Config:
"""Returns the web3_config corresponding to up vs down direction"""
if direction == True:
return self.web3_config_up
else:
return self.web3_config_down

@enforce_types
def calc_stakes(self) -> Tuple[float, float]:
Expand Down
3 changes: 3 additions & 0 deletions pdr_backend/util/web3_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self, rpc_url: str, private_key: Optional[str] = None):
)
self.w3.middleware_onion.add(http_retry_request_middleware)

def copy_with_pk(self, pk: str):
return Web3Config(self.rpc_url, pk)

def get_block(
self, block: BlockIdentifier, full_transactions: bool = False, tries: int = 0
) -> BlockData:
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ env =
D:RPC_URL=http://127.0.0.1:8545
D:SUBGRAPH_URL=http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph
D:PRIVATE_KEY=0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58
D:PRIVATE_KEY2=0xef4b441145c1d0f3b4bc6d61d29f5c6e502359481152f869247c7a4244d45209
1 change: 1 addition & 0 deletions system_tests/test_predictoor_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def setup_mock_web3_pp(mock_feeds, mock_predictoor_contract):

mock_web3_config = Mock(spec=Web3Config)
mock_web3_config.w3 = Mock()
mock_web3_config.owner = "0xowner"
mock_web3_config.get_block.return_value = {"timestamp": 100}
mock_web3_pp.web3_config = mock_web3_config

Expand Down

0 comments on commit 50fa66c

Please sign in to comment.