Skip to content

Commit

Permalink
test: refactor staking precompile test suite (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
zakir-code authored Oct 30, 2024
1 parent 79ad4c0 commit 506ae27
Show file tree
Hide file tree
Showing 35 changed files with 1,132 additions and 2,007 deletions.
12 changes: 12 additions & 0 deletions contract/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"math/big"
"strings"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
Expand Down Expand Up @@ -149,3 +150,14 @@ func PackBridgeCallCheckpoint(gravityID, methodName [32]byte, sender, refund com
eventNonce,
)
}

func unpackRetIsOk(abi abi.ABI, method string, res *evmtypes.MsgEthereumTxResponse) (*evmtypes.MsgEthereumTxResponse, error) {
var ret struct{ Value bool }
if err := abi.UnpackIntoInterface(&ret, method, res.Ret); err != nil {
return res, sdkerrors.ErrInvalidType.Wrapf("failed to unpack %s: %s", method, err.Error())
}
if !ret.Value {
return res, sdkerrors.ErrLogic.Wrapf("failed to execute %s", method)
}
return res, nil
}
18 changes: 3 additions & 15 deletions contract/erc20_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"math/big"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/evmos/ethermint/x/evm/types"
Expand Down Expand Up @@ -80,23 +79,12 @@ func (k ERC20TokenKeeper) Allowance(ctx context.Context, contractAddr, owner, sp
return allowanceRes.Value, nil
}

func (k ERC20TokenKeeper) unpackRet(method string, res *types.MsgEthereumTxResponse) (*types.MsgEthereumTxResponse, error) {
var result struct{ Value bool }
if err := k.abi.UnpackIntoInterface(&result, method, res.Ret); err != nil {
return res, sdkerrors.ErrInvalidType.Wrapf("failed to unpack transfer: %s", err.Error())
}
if !result.Value {
return res, sdkerrors.ErrLogic.Wrapf("failed to execute %s", method)
}
return res, nil
}

func (k ERC20TokenKeeper) Approve(ctx context.Context, contractAddr, from, spender common.Address, amount *big.Int) (*types.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, contractAddr, nil, k.abi, "approve", spender, amount)
if err != nil {
return nil, err
}
return k.unpackRet("approve", res)
return unpackRetIsOk(k.abi, "approve", res)
}

// PackMint only used for testing
Expand All @@ -117,15 +105,15 @@ func (k ERC20TokenKeeper) Transfer(ctx context.Context, contractAddr, from, rece
if err != nil {
return nil, err
}
return k.unpackRet("transfer", res)
return unpackRetIsOk(k.abi, "transfer", res)
}

func (k ERC20TokenKeeper) TransferFrom(ctx context.Context, contractAddr, from, sender, receiver common.Address, amount *big.Int) (*types.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, contractAddr, nil, k.abi, "transferFrom", sender, receiver, amount)
if err != nil {
return nil, err
}
return k.unpackRet("transferFrom", res)
return unpackRetIsOk(k.abi, "transferFrom", res)
}

func (k ERC20TokenKeeper) TransferOwnership(ctx context.Context, contractAddr, owner, newOwner common.Address) (*types.MsgEthereumTxResponse, error) {
Expand Down
94 changes: 28 additions & 66 deletions x/staking/types/contract.go → contract/staking.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package types
package contract

import (
"errors"
Expand Down Expand Up @@ -59,24 +59,6 @@ func (args *ApproveSharesArgs) GetValidator() sdk.ValAddress {
return valAddr
}

type DelegateArgs struct {
Validator string `abi:"_val"`
}

// Validate validates the args
func (args *DelegateArgs) Validate() error {
if _, err := sdk.ValAddressFromBech32(args.Validator); err != nil {
return fmt.Errorf("invalid validator address: %s", args.Validator)
}
return nil
}

// GetValidator returns the validator address, caller must ensure the validator address is valid
func (args *DelegateArgs) GetValidator() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.Validator)
return valAddr
}

type DelegateV2Args struct {
Validator string `abi:"_val"`
Amount *big.Int `abi:"_amount"`
Expand All @@ -93,6 +75,12 @@ func (args *DelegateV2Args) Validate() error {
return nil
}

// GetValidator returns the validator address, caller must ensure the validator address is valid
func (args *DelegateV2Args) GetValidator() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.Validator)
return valAddr
}

type DelegationArgs struct {
Validator string `abi:"_val"`
Delegator common.Address `abi:"_del"`
Expand Down Expand Up @@ -131,58 +119,38 @@ func (args *DelegationRewardsArgs) GetValidator() sdk.ValAddress {
return valAddr
}

type RedelegateArgs struct {
type RedelegateV2Args struct {
ValidatorSrc string `abi:"_valSrc"`
ValidatorDst string `abi:"_valDst"`
Shares *big.Int `abi:"_shares"`
Amount *big.Int `abi:"_amount"`
}

// Validate validates the args
func (args *RedelegateArgs) Validate() error {
func (args *RedelegateV2Args) Validate() error {
if _, err := sdk.ValAddressFromBech32(args.ValidatorSrc); err != nil {
return fmt.Errorf("invalid validator src address: %s", args.ValidatorSrc)
}
if _, err := sdk.ValAddressFromBech32(args.ValidatorDst); err != nil {
return fmt.Errorf("invalid validator dst address: %s", args.ValidatorDst)
}
if args.Shares == nil || args.Shares.Sign() <= 0 {
return errors.New("invalid shares")
if args.Amount == nil || args.Amount.Sign() <= 0 {
return errors.New("invalid amount")
}
return nil
}

// GetValidatorSrc returns the validator src address, caller must ensure the validator address is valid
func (args *RedelegateArgs) GetValidatorSrc() sdk.ValAddress {
func (args *RedelegateV2Args) GetValidatorSrc() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.ValidatorSrc)
return valAddr
}

// GetValidatorDst returns the validator dest address, caller must ensure the validator address is valid
func (args *RedelegateArgs) GetValidatorDst() sdk.ValAddress {
func (args *RedelegateV2Args) GetValidatorDst() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.ValidatorDst)
return valAddr
}

type RedelegateV2Args struct {
ValidatorSrc string `abi:"_valSrc"`
ValidatorDst string `abi:"_valDst"`
Amount *big.Int `abi:"_amount"`
}

// Validate validates the args
func (args *RedelegateV2Args) Validate() error {
if _, err := sdk.ValAddressFromBech32(args.ValidatorSrc); err != nil {
return fmt.Errorf("invalid validator src address: %s", args.ValidatorSrc)
}
if _, err := sdk.ValAddressFromBech32(args.ValidatorDst); err != nil {
return fmt.Errorf("invalid validator dst address: %s", args.ValidatorDst)
}
if args.Amount == nil || args.Amount.Sign() <= 0 {
return errors.New("invalid amount")
}
return nil
}

type TransferSharesArgs struct {
Validator string `abi:"_val"`
To common.Address `abi:"_to"`
Expand All @@ -206,6 +174,11 @@ func (args *TransferSharesArgs) GetValidator() sdk.ValAddress {
return valAddr
}

type TransferSharesRet struct {
Token *big.Int
Reward *big.Int
}

type TransferFromSharesArgs struct {
Validator string `abi:"_val"`
From common.Address `abi:"_from"`
Expand All @@ -230,26 +203,9 @@ func (args *TransferFromSharesArgs) GetValidator() sdk.ValAddress {
return valAddr
}

type UndelegateArgs struct {
Validator string `abi:"_val"`
Shares *big.Int `abi:"_shares"`
}

// Validate validates the args
func (args *UndelegateArgs) Validate() error {
if _, err := sdk.ValAddressFromBech32(args.Validator); err != nil {
return fmt.Errorf("invalid validator address: %s", args.Validator)
}
if args.Shares == nil || args.Shares.Sign() <= 0 {
return errors.New("invalid shares")
}
return nil
}

// GetValidator returns the validator address, caller must ensure the validator address is valid
func (args *UndelegateArgs) GetValidator() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.Validator)
return valAddr
type TransferFromSharesRet struct {
Token *big.Int
Reward *big.Int
}

type UndelegateV2Args struct {
Expand All @@ -268,6 +224,12 @@ func (args *UndelegateV2Args) Validate() error {
return nil
}

// GetValidator returns the validator address, caller must ensure the validator address is valid
func (args *UndelegateV2Args) GetValidator() sdk.ValAddress {
valAddr, _ := sdk.ValAddressFromBech32(args.Validator)
return valAddr
}

type WithdrawArgs struct {
Validator string `abi:"_val"`
}
Expand Down
157 changes: 157 additions & 0 deletions contract/staking_precompile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package contract

import (
"context"
"math/big"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
evmtypes "github.com/evmos/ethermint/x/evm/types"
)

type StakingPrecompileKeeper struct {
Caller
abi abi.ABI
contractAddr common.Address
}

func NewStakingPrecompileKeeper(caller Caller, contractAddr common.Address) StakingPrecompileKeeper {
if IsZeroEthAddress(contractAddr) {
contractAddr = common.HexToAddress(StakingAddress)
}
return StakingPrecompileKeeper{
Caller: caller,
abi: MustABIJson(IStakingMetaData.ABI),
contractAddr: contractAddr,
}
}

func (k StakingPrecompileKeeper) WithContractAddr(c common.Address) StakingPrecompileKeeper {
stakingPrecompileKeeper := k
stakingPrecompileKeeper.contractAddr = c
return stakingPrecompileKeeper
}

func (k StakingPrecompileKeeper) AllowanceShares(ctx context.Context, args AllowanceSharesArgs) (*big.Int, error) {
var output struct {
Shares *big.Int
}
err := k.QueryContract(ctx, common.Address{}, k.contractAddr, k.abi, "allowanceShares", &output, args.Validator, args.Owner, args.Spender)
if err != nil {
return nil, err
}
return output.Shares, nil
}

func (k StakingPrecompileKeeper) Delegation(ctx context.Context, args DelegationArgs) (*big.Int, *big.Int, error) {
var output struct {
Shares *big.Int
DelegateAmount *big.Int
}
err := k.QueryContract(ctx, common.Address{}, k.contractAddr, k.abi, "delegation", &output, args.Validator, args.Delegator)
if err != nil {
return nil, nil, err
}
return output.Shares, output.DelegateAmount, nil
}

func (k StakingPrecompileKeeper) DelegationRewards(ctx context.Context, args DelegationRewardsArgs) (*big.Int, error) {
var output struct {
Rewards *big.Int
}
err := k.QueryContract(ctx, common.Address{}, k.contractAddr, k.abi, "delegationRewards", &output, args.Validator, args.Delegator)
if err != nil {
return nil, err
}
return output.Rewards, nil
}

func (k StakingPrecompileKeeper) ValidatorList(ctx context.Context, args ValidatorListArgs) ([]string, error) {
var valList []string
err := k.QueryContract(ctx, common.Address{}, k.contractAddr, k.abi, "validatorList", &valList, args.SortBy)
if err != nil {
return nil, err
}
return valList, nil
}

func (k StakingPrecompileKeeper) SlashingInfo(ctx context.Context, args SlashingInfoArgs) (bool, *big.Int, error) {
var output struct {
Jailed bool
Missed *big.Int
}
err := k.QueryContract(ctx, common.Address{}, k.contractAddr, k.abi, "slashingInfo", &output, args.Validator)
if err != nil {
return false, nil, err
}
return output.Jailed, output.Missed, nil
}

func (k StakingPrecompileKeeper) ApproveShares(ctx context.Context, from common.Address, args ApproveSharesArgs) (*evmtypes.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "approveShares", args.Validator, args.Spender, args.Shares)
if err != nil {
return nil, err
}
return unpackRetIsOk(k.abi, "approveShares", res)
}

func (k StakingPrecompileKeeper) TransferShares(ctx context.Context, from common.Address, args TransferSharesArgs) (*evmtypes.MsgEthereumTxResponse, *TransferSharesRet, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "transferShares", args.Validator, args.To, args.Shares)
if err != nil {
return nil, nil, err
}
ret := new(TransferSharesRet)
if err = k.abi.UnpackIntoInterface(ret, "transferShares", res.Ret); err != nil {
return res, nil, sdkerrors.ErrInvalidType.Wrapf("failed to unpack transferShares: %s", err.Error())
}
return res, ret, nil
}

func (k StakingPrecompileKeeper) TransferFromShares(ctx context.Context, from common.Address, args TransferFromSharesArgs) (*evmtypes.MsgEthereumTxResponse, *TransferFromSharesRet, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "transferFromShares", args.Validator, args.From, args.To, args.Shares)
if err != nil {
return nil, nil, err
}
ret := new(TransferFromSharesRet)
if err = k.abi.UnpackIntoInterface(ret, "transferFromShares", res.Ret); err != nil {
return res, nil, sdkerrors.ErrInvalidType.Wrapf("failed to unpack transferFromShares: %s", err.Error())
}
return res, ret, nil
}

func (k StakingPrecompileKeeper) Withdraw(ctx context.Context, from common.Address, args WithdrawArgs) (*evmtypes.MsgEthereumTxResponse, *big.Int, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "withdraw", args.Validator)
if err != nil {
return nil, nil, err
}
ret := struct{ Reward *big.Int }{}
if err = k.abi.UnpackIntoInterface(&ret, "withdraw", res.Ret); err != nil {
return res, nil, sdkerrors.ErrInvalidType.Wrapf("failed to unpack withdraw: %s", err.Error())
}
return res, ret.Reward, nil
}

func (k StakingPrecompileKeeper) DelegateV2(ctx context.Context, from common.Address, args DelegateV2Args) (*evmtypes.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "delegateV2", args.Validator, args.Amount)
if err != nil {
return nil, err
}
return unpackRetIsOk(k.abi, "delegateV2", res)
}

func (k StakingPrecompileKeeper) RedelegateV2(ctx context.Context, from common.Address, args RedelegateV2Args) (*evmtypes.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "redelegateV2", args.ValidatorSrc, args.ValidatorDst, args.Amount)
if err != nil {
return nil, err
}
return unpackRetIsOk(k.abi, "redelegateV2", res)
}

func (k StakingPrecompileKeeper) UndelegateV2(ctx context.Context, from common.Address, args UndelegateV2Args) (*evmtypes.MsgEthereumTxResponse, error) {
res, err := k.ApplyContract(ctx, from, k.contractAddr, nil, k.abi, "undelegateV2", args.Validator, args.Amount)
if err != nil {
return nil, err
}
return unpackRetIsOk(k.abi, "undelegateV2", res)
}
Loading

0 comments on commit 506ae27

Please sign in to comment.