Skip to content

Commit

Permalink
Merge branch 'main' into chia_dev_gh_test
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Nov 25, 2024
2 parents 95df856 + 0516e83 commit 4a702c9
Show file tree
Hide file tree
Showing 15 changed files with 263 additions and 144 deletions.
40 changes: 32 additions & 8 deletions chia/_tests/environments/wallet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import contextlib
import json
import operator
import unittest
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, ClassVar, Optional, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Union, cast

from chia._tests.environments.common import ServiceEnvironment
from chia.rpc.full_node_rpc_client import FullNodeRpcClient
Expand All @@ -15,7 +17,6 @@
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint32
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.transaction_type import CLAWBACK_INCOMING_TRANSACTION_TYPES
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig
Expand Down Expand Up @@ -260,6 +261,22 @@ async def wait_for_transactions_to_settle(
return pending_txs


class NewPuzzleHashError(Exception):
pass


def catch_puzzle_hash_errors(func: Any) -> Any:
@contextlib.asynccontextmanager
async def catching_puzhash_errors(self: WalletStateManager, *args: Any, **kwargs: Any) -> Any:
try:
async with func(self, *args, **kwargs) as action_scope:
yield action_scope
except NewPuzzleHashError:
pass

return catching_puzhash_errors


@dataclass
class WalletTestFramework:
full_node: FullNodeSimulator
Expand All @@ -268,6 +285,15 @@ class WalletTestFramework:
environments: list[WalletEnvironment]
tx_config: TXConfig = DEFAULT_TX_CONFIG

@staticmethod
@contextlib.contextmanager
def new_puzzle_hashes_allowed() -> Iterator[None]:
with unittest.mock.patch(
"chia.wallet.wallet_state_manager.WalletStateManager.new_action_scope",
catch_puzzle_hash_errors(WalletStateManager.new_action_scope),
):
yield

async def process_pending_states(
self, state_transitions: list[WalletStateTransition], invalid_transactions: list[bytes32] = []
) -> None:
Expand All @@ -284,13 +310,11 @@ async def process_pending_states(
"""
# Take note of the number of puzzle hashes if we're supposed to be reusing
if self.tx_config.reuse_puzhash:
puzzle_hash_indexes: list[dict[uint32, Optional[DerivationRecord]]] = []
puzzle_hash_indexes: list[dict[uint32, int]] = []
for env in self.environments:
ph_indexes: dict[uint32, Optional[DerivationRecord]] = {}
ph_indexes: dict[uint32, int] = {}
for wallet_id in env.wallet_state_manager.wallets:
ph_indexes[
wallet_id
] = await env.wallet_state_manager.puzzle_store.get_current_derivation_record_for_wallet(wallet_id)
ph_indexes[wallet_id] = await env.wallet_state_manager.puzzle_store.get_unused_count(wallet_id)
puzzle_hash_indexes.append(ph_indexes)

pending_txs: list[list[TransactionRecord]] = []
Expand Down Expand Up @@ -359,5 +383,5 @@ async def process_pending_states(
for env, ph_indexes_before in zip(self.environments, puzzle_hash_indexes):
for wallet_id, ph_index in zip(env.wallet_state_manager.wallets, ph_indexes_before):
assert ph_indexes_before[wallet_id] == (
await env.wallet_state_manager.puzzle_store.get_current_derivation_record_for_wallet(wallet_id)
await env.wallet_state_manager.puzzle_store.get_unused_count(wallet_id)
)
164 changes: 96 additions & 68 deletions chia/_tests/wallet/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import contextlib
import unittest
from collections.abc import AsyncIterator, Awaitable
from contextlib import AsyncExitStack
from dataclasses import replace
Expand All @@ -14,7 +16,7 @@
run_block_generator2,
)

from chia._tests.environments.wallet import WalletEnvironment, WalletState, WalletTestFramework
from chia._tests.environments.wallet import NewPuzzleHashError, WalletEnvironment, WalletState, WalletTestFramework
from chia._tests.util.setup_nodes import setup_simulators_and_wallets_service
from chia._tests.wallet.wallet_block_tools import WalletBlockTools
from chia.consensus.constants import ConsensusConstants
Expand Down Expand Up @@ -144,6 +146,28 @@ def tx_config(request: Any) -> TXConfig:
return replace(DEFAULT_TX_CONFIG, reuse_puzhash=request.param)


def new_action_scope_wrapper(func: Any) -> Any:
@contextlib.asynccontextmanager
async def wrapped_new_action_scope(self: WalletStateManager, *args: Any, **kwargs: Any) -> Any:
# Take note of the number of puzzle hashes if we're supposed to be reusing
ph_indexes: dict[uint32, int] = {}
for wallet_id in self.wallets:
ph_indexes[wallet_id] = await self.puzzle_store.get_unused_count(wallet_id)

async with func(self, *args, **kwargs) as action_scope:
yield action_scope

# Finally, check that the number of puzzle hashes did or did not increase by the specified amount
if action_scope.config.tx_config.reuse_puzhash:
for wallet_id, ph_index in zip(self.wallets, ph_indexes):
if not ph_indexes[wallet_id] == (await self.puzzle_store.get_unused_count(wallet_id)):
raise NewPuzzleHashError(
f"wallet ID {wallet_id} generated new puzzle hashes while reuse_puzhash was False"
)

return wrapped_new_action_scope


# This fixture automatically creates 4 parametrized tests trusted/untrusted x reuse/new derivations
# These parameterizations can be skipped by manually specifying "trusted" or "reuse puzhash" to the fixture
@pytest.fixture(scope="function")
Expand Down Expand Up @@ -174,77 +198,81 @@ async def wallet_environments(

full_node[0]._api.full_node.config = {**full_node[0]._api.full_node.config, **config_overrides}

wallet_rpc_clients: list[WalletRpcClient] = []
async with AsyncExitStack() as astack:
for service in wallet_services:
service._node.config = {
**service._node.config,
"trusted_peers": (
{full_node[0]._api.server.node_id.hex(): full_node[0]._api.server.node_id.hex()}
if trusted_full_node
else {}
),
**config_overrides,
}
service._node.wallet_state_manager.config = service._node.config
# Shorten the 10 seconds default value
service._node.coin_state_retry_seconds = 2
await service._node.server.start_client(
PeerInfo(bt.config["self_hostname"], full_node[0]._api.full_node.server.get_port()), None
)
wallet_rpc_clients.append(
await astack.enter_async_context(
WalletRpcClient.create_as_context(
bt.config["self_hostname"],
# Semantics guarantee us a non-None value here
service.rpc_server.listen_port, # type: ignore[union-attr]
service.root_path,
service.config,
new_action_scope_wrapped = new_action_scope_wrapper(WalletStateManager.new_action_scope)
with unittest.mock.patch(
"chia.wallet.wallet_state_manager.WalletStateManager.new_action_scope", new=new_action_scope_wrapped
):
wallet_rpc_clients: list[WalletRpcClient] = []
async with AsyncExitStack() as astack:
for service in wallet_services:
service._node.config = {
**service._node.config,
"trusted_peers": (
{full_node[0]._api.server.node_id.hex(): full_node[0]._api.server.node_id.hex()}
if trusted_full_node
else {}
),
**config_overrides,
}
service._node.wallet_state_manager.config = service._node.config
# Shorten the 10 seconds default value
service._node.coin_state_retry_seconds = 2
await service._node.server.start_client(
PeerInfo(bt.config["self_hostname"], full_node[0]._api.full_node.server.get_port()), None
)
wallet_rpc_clients.append(
await astack.enter_async_context(
WalletRpcClient.create_as_context(
bt.config["self_hostname"],
# Semantics guarantee us a non-None value here
service.rpc_server.listen_port, # type: ignore[union-attr]
service.root_path,
service.config,
)
)
)
)

wallet_states: list[WalletState] = []
for service, blocks_needed in zip(wallet_services, request.param["blocks_needed"]):
if blocks_needed > 0:
await full_node[0]._api.farm_blocks_to_wallet(
count=blocks_needed, wallet=service._node.wallet_state_manager.main_wallet
wallet_states: list[WalletState] = []
for service, blocks_needed in zip(wallet_services, request.param["blocks_needed"]):
if blocks_needed > 0:
await full_node[0]._api.farm_blocks_to_wallet(
count=blocks_needed, wallet=service._node.wallet_state_manager.main_wallet
)
await full_node[0]._api.wait_for_wallet_synced(wallet_node=service._node, timeout=20)
wallet_states.append(
WalletState(
Balance(
confirmed_wallet_balance=uint128(2_000_000_000_000 * blocks_needed),
unconfirmed_wallet_balance=uint128(2_000_000_000_000 * blocks_needed),
spendable_balance=uint128(2_000_000_000_000 * blocks_needed),
pending_change=uint64(0),
max_send_amount=uint128(2_000_000_000_000 * blocks_needed),
unspent_coin_count=uint32(2 * blocks_needed),
pending_coin_removal_count=uint32(0),
),
)
)
await full_node[0]._api.wait_for_wallet_synced(wallet_node=service._node, timeout=20)
wallet_states.append(
WalletState(
Balance(
confirmed_wallet_balance=uint128(2_000_000_000_000 * blocks_needed),
unconfirmed_wallet_balance=uint128(2_000_000_000_000 * blocks_needed),
spendable_balance=uint128(2_000_000_000_000 * blocks_needed),
pending_change=uint64(0),
max_send_amount=uint128(2_000_000_000_000 * blocks_needed),
unspent_coin_count=uint32(2 * blocks_needed),
pending_coin_removal_count=uint32(0),
),

assert full_node[0].rpc_server is not None
client_node = await astack.enter_async_context(
FullNodeRpcClient.create_as_context(
bt.config["self_hostname"],
full_node[0].rpc_server.listen_port,
full_node[0].root_path,
full_node[0].config,
)
)

assert full_node[0].rpc_server is not None
client_node = await astack.enter_async_context(
FullNodeRpcClient.create_as_context(
bt.config["self_hostname"],
full_node[0].rpc_server.listen_port,
full_node[0].root_path,
full_node[0].config,
yield WalletTestFramework(
full_node[0]._api,
client_node,
trusted_full_node,
[
WalletEnvironment(
service=service,
rpc_client=rpc_client,
wallet_states={uint32(1): wallet_state},
)
for service, rpc_client, wallet_state in zip(wallet_services, wallet_rpc_clients, wallet_states)
],
tx_config,
)
)
yield WalletTestFramework(
full_node[0]._api,
client_node,
trusted_full_node,
[
WalletEnvironment(
service=service,
rpc_client=rpc_client,
wallet_states={uint32(1): wallet_state},
)
for service, rpc_client, wallet_state in zip(wallet_services, wallet_rpc_clients, wallet_states)
],
tx_config,
)
41 changes: 23 additions & 18 deletions chia/_tests/wallet/did_wallet/test_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ async def test_did_find_lost_did(self, wallet_environments: WalletTestFramework)
# Delete the coin and wallet
coin = await did_wallet.get_coin()
await wallet_node_0.wallet_state_manager.coin_store.delete_coin_record(coin.name())
await wallet_node_0.wallet_state_manager.user_store.delete_wallet(did_wallet.wallet_info.id)
await wallet_node_0.wallet_state_manager.delete_wallet(did_wallet.wallet_info.id)
wallet_node_0.wallet_state_manager.wallets.pop(did_wallet.wallet_info.id)
assert len(wallet_node_0.wallet_state_manager.wallets) == 1
# Find lost DID
Expand All @@ -857,6 +857,21 @@ async def test_did_find_lost_did(self, wallet_environments: WalletTestFramework)
)
)
did_wallet = wallet_node_0.wallet_state_manager.wallets[did_wallets[0].id]
env_0.wallet_aliases["did_found"] = did_wallets[0].id
await env_0.change_balances(
{
"did_found": {
"init": True,
"confirmed_wallet_balance": 101,
"unconfirmed_wallet_balance": 101,
"spendable_balance": 101,
"max_send_amount": 101,
"unspent_coin_count": 1,
}
}
)
await env_0.check_balances()

# Spend DID
recovery_list = [bytes32.fromhex(did_wallet.get_my_DID())]
await did_wallet.update_recovery_list(recovery_list, uint64(1))
Expand All @@ -866,31 +881,21 @@ async def test_did_find_lost_did(self, wallet_environments: WalletTestFramework)
) as action_scope:
await did_wallet.create_update_spend(action_scope)

env_0.wallet_aliases["did_found"] = 3

await wallet_environments.process_pending_states(
[
WalletStateTransition(
pre_block_balance_updates={
"did_found": {
"init": True,
"confirmed_wallet_balance": 101,
"unconfirmed_wallet_balance": 202, # Seems strange
"spendable_balance": 101,
"max_send_amount": 101,
"unspent_coin_count": 1,
"spendable_balance": -101,
"max_send_amount": -101,
"pending_change": 101,
"pending_coin_removal_count": 1,
"set_remainder": True,
},
},
post_block_balance_updates={
"did_found": {
"confirmed_wallet_balance": 0,
"unconfirmed_wallet_balance": -101,
"spendable_balance": 0,
"max_send_amount": 0,
"unspent_coin_count": 0,
"spendable_balance": 101,
"max_send_amount": 101,
"pending_change": -101,
"pending_coin_removal_count": -1,
},
Expand All @@ -902,7 +907,7 @@ async def test_did_find_lost_did(self, wallet_environments: WalletTestFramework)
# Delete the coin and change inner puzzle
coin = await did_wallet.get_coin()
await wallet_node_0.wallet_state_manager.coin_store.delete_coin_record(coin.name())
new_inner_puzzle = await did_wallet.get_new_did_innerpuz()
new_inner_puzzle = await did_wallet.get_did_innerpuz(new=True)
did_wallet.did_info = dataclasses.replace(did_wallet.did_info, current_inner=new_inner_puzzle)
# Recovery the coin
assert did_wallet.did_info.origin_coin is not None # mypy
Expand Down Expand Up @@ -1051,7 +1056,7 @@ async def test_did_attest_after_recovery(self, wallet_environments: WalletTestFr
backup_data,
)
env_0.wallet_aliases["did_2"] = 3
new_ph = await did_wallet_3.get_new_did_inner_hash()
new_ph = await did_wallet_3.get_did_inner_hash(new=True)
coin = await did_wallet_2.get_coin()
pubkey = (
await did_wallet_3.wallet_state_manager.get_unused_derivation_record(did_wallet_3.wallet_info.id)
Expand Down Expand Up @@ -1138,7 +1143,7 @@ async def test_did_attest_after_recovery(self, wallet_environments: WalletTestFr
)
env_1.wallet_aliases["did_2"] = 3
coin = await did_wallet.get_coin()
new_ph = await did_wallet_4.get_new_did_inner_hash()
new_ph = await did_wallet_4.get_did_inner_hash(new=True)
pubkey = (
await did_wallet_4.wallet_state_manager.get_unused_derivation_record(did_wallet_4.wallet_info.id)
).pubkey
Expand Down
Loading

0 comments on commit 4a702c9

Please sign in to comment.