Skip to content

Commit

Permalink
Merge pull request onflow#6232 from onflow/ramtin/evm-refactor-precom…
Browse files Browse the repository at this point in the history
…piled-call-tracker

[Flow EVM] Refactoring precompiled contract call tracker
  • Loading branch information
ramtinms authored Jul 22, 2024
2 parents 8658ed9 + 4612727 commit f9bd3bf
Show file tree
Hide file tree
Showing 13 changed files with 401 additions and 160 deletions.
10 changes: 6 additions & 4 deletions fvm/evm/emulator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type Config struct {
TxContext *gethVM.TxContext
// base unit of gas for direct calls
DirectCallBaseGasUsage uint64
// list of precompiles
ExtraPrecompiles []types.PrecompiledContract
// captures extra precompiled calls
PCTracker *CallTracker
}

func (c *Config) ChainRules() gethParams.Rules {
Expand Down Expand Up @@ -100,6 +100,7 @@ func defaultConfig() *Config {
},
GetPrecompile: gethCore.GetPrecompile,
},
PCTracker: NewCallTracker(),
}
}

Expand Down Expand Up @@ -191,7 +192,9 @@ func WithExtraPrecompiledContracts(precompiledContracts []types.PrecompiledContr
return func(c *Config) *Config {
extraPreCompMap := make(map[gethCommon.Address]gethVM.PrecompiledContract)
for _, pc := range precompiledContracts {
extraPreCompMap[pc.Address().ToCommon()] = pc
// wrap pcs for tracking
wpc := c.PCTracker.RegisterPrecompiledContract(pc)
extraPreCompMap[pc.Address().ToCommon()] = wpc
}
c.BlockContext.GetPrecompile = func(rules gethParams.Rules, addr gethCommon.Address) (gethVM.PrecompiledContract, bool) {
prec, found := extraPreCompMap[addr]
Expand All @@ -200,7 +203,6 @@ func WithExtraPrecompiledContracts(precompiledContracts []types.PrecompiledContr
}
return gethCore.GetPrecompile(rules, addr)
}
c.ExtraPrecompiles = precompiledContracts
return c
}
}
Expand Down
42 changes: 7 additions & 35 deletions fvm/evm/emulator/emulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,14 +533,14 @@ func (proc *procedure) runDirect(
txHash gethCommon.Hash,
txIndex uint,
) (*types.Result, error) {
// set the nonce for the message (needed for some opeartions like deployment)
// set the nonce for the message (needed for some operations like deployment)
msg.Nonce = proc.state.GetNonce(msg.From)
proc.evm.TxContext.Origin = msg.From
res, err := proc.run(msg, txHash, txIndex, types.DirectCallTxType)
if err != nil {
return nil, err
}
// all commmit errors (StateDB errors) has to be returned
// all commit errors (StateDB errors) has to be returned
return res, proc.commit(true)
}

Expand All @@ -561,8 +561,8 @@ func (proc *procedure) run(
TxHash: txHash,
}

// reset precompile tracking
proc.resetPrecompileTracking()
// reset precompile tracking in case
proc.config.PCTracker.Reset()
gasPool := (*gethCore.GasPool)(&proc.config.BlockContext.GasLimit)
execResult, err := gethCore.NewStateTransition(
proc.evm,
Expand All @@ -586,12 +586,9 @@ func (proc *procedure) run(
res.GasConsumed = execResult.UsedGas
res.GasRefund = proc.state.GetRefund()
res.Index = uint16(txIndex)

if proc.extraPrecompiledIsCalled() {
res.PrecompiledCalls, err = proc.capturePrecompiledCalls()
if err != nil {
return nil, err
}
res.PrecompiledCalls, err = proc.config.PCTracker.CapturedCalls()
if err != nil {
return nil, err
}
// we need to capture the returned value no matter the status
// if the tx is reverted the error message is returned as returned value
Expand All @@ -617,31 +614,6 @@ func (proc *procedure) run(
return &res, nil
}

func (proc *procedure) resetPrecompileTracking() {
for _, pc := range proc.config.ExtraPrecompiles {
pc.Reset()
}
}

func (proc *procedure) extraPrecompiledIsCalled() bool {
for _, pc := range proc.config.ExtraPrecompiles {
if pc.IsCalled() {
return true
}
}
return false
}

func (proc *procedure) capturePrecompiledCalls() ([]byte, error) {
apc := make(types.AggregatedPrecompiledCalls, 0)
for _, pc := range proc.config.ExtraPrecompiles {
if pc.IsCalled() {
apc = append(apc, *pc.CapturedCalls())
}
}
return apc.Encode()
}

func (proc *procedure) captureTraceBegin(
depth int,
typ gethVM.OpCode,
Expand Down
41 changes: 3 additions & 38 deletions fvm/evm/emulator/emulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,6 @@ func TestCallingExtraPrecompiles(t *testing.T) {
input := []byte{1, 2}
output := []byte{3, 4}
addr := testutils.RandomAddress(t)
isCalled := false
capturedCall := &types.PrecompiledCalls{
Address: addr,
RequiredGasCalls: []types.RequiredGasCall{{
Expand All @@ -774,22 +773,12 @@ func TestCallingExtraPrecompiles(t *testing.T) {
return addr
},
RequiredGasFunc: func(input []byte) uint64 {
isCalled = true
return uint64(10)
},
RunFunc: func(inp []byte) ([]byte, error) {
isCalled = true
require.Equal(t, input, inp)
return output, nil
},
IsCalledFunc: func() bool {
return isCalled
},
CapturedCallsFunc: func() *types.PrecompiledCalls {
return capturedCall
},
ResetFunc: func() {
},
}

ctx := types.NewDefaultBlockContext(blockNumber.Uint64())
Expand Down Expand Up @@ -1057,12 +1046,9 @@ func TestTransactionTracing(t *testing.T) {
}

type MockedPrecompiled struct {
AddressFunc func() types.Address
RequiredGasFunc func(input []byte) uint64
RunFunc func(input []byte) ([]byte, error)
CapturedCallsFunc func() *types.PrecompiledCalls
ResetFunc func()
IsCalledFunc func() bool
AddressFunc func() types.Address
RequiredGasFunc func(input []byte) uint64
RunFunc func(input []byte) ([]byte, error)
}

var _ types.PrecompiledContract = &MockedPrecompiled{}
Expand All @@ -1081,30 +1067,9 @@ func (mp *MockedPrecompiled) RequiredGas(input []byte) uint64 {
return mp.RequiredGasFunc(input)
}

func (mp *MockedPrecompiled) IsCalled() bool {
if mp.IsCalledFunc == nil {
panic("IsCalled not set for the mocked precompiled contract")
}
return mp.IsCalledFunc()
}

func (mp *MockedPrecompiled) Run(input []byte) ([]byte, error) {
if mp.RunFunc == nil {
panic("Run not set for the mocked precompiled contract")
}
return mp.RunFunc(input)
}

func (mp *MockedPrecompiled) CapturedCalls() *types.PrecompiledCalls {
if mp.CapturedCallsFunc == nil {
panic("CapturedCalls not set for the mocked precompiled contract")
}
return mp.CapturedCallsFunc()
}

func (mp *MockedPrecompiled) Reset() {
if mp.ResetFunc == nil {
panic("Reset not set for the mocked precompiled contract")
}
mp.ResetFunc()
}
124 changes: 124 additions & 0 deletions fvm/evm/emulator/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package emulator

import (
"bytes"
"sort"

"github.com/onflow/flow-go/fvm/evm/types"
)

// CallTracker captures precompiled calls
type CallTracker struct {
callsByAddress map[types.Address]*types.PrecompiledCalls
}

// NewCallTracker constructs a new CallTracker
func NewCallTracker() *CallTracker {
return &CallTracker{}
}

// RegisterPrecompiledContract registers a precompiled contract for tracking
func (ct *CallTracker) RegisterPrecompiledContract(pc types.PrecompiledContract) types.PrecompiledContract {
return &WrappedPrecompiledContract{
pc: pc,
ct: ct,
}
}

// CaptureRequiredGas captures a required gas call
func (ct *CallTracker) CaptureRequiredGas(address types.Address, input []byte, output uint64) {
if ct.callsByAddress == nil {
ct.callsByAddress = make(map[types.Address]*types.PrecompiledCalls)
}
calls, found := ct.callsByAddress[address]
if !found {
calls = &types.PrecompiledCalls{
Address: address,
}
ct.callsByAddress[address] = calls
}

calls.RequiredGasCalls = append(calls.RequiredGasCalls, types.RequiredGasCall{
Input: input,
Output: output,
})
}

// CaptureRun captures a run calls
func (ct *CallTracker) CaptureRun(address types.Address, input []byte, output []byte, err error) {
if ct.callsByAddress == nil {
ct.callsByAddress = make(map[types.Address]*types.PrecompiledCalls)
}
calls, found := ct.callsByAddress[address]
if !found {
calls = &types.PrecompiledCalls{
Address: address,
}
ct.callsByAddress[address] = calls
}
errMsg := ""
if err != nil {
errMsg = err.Error()
}
calls.RunCalls = append(calls.RunCalls, types.RunCall{
Input: input,
Output: output,
ErrorMsg: errMsg,
})
}

// IsCalled returns true if any calls has been captured
func (ct *CallTracker) IsCalled() bool {
return len(ct.callsByAddress) != 0
}

// Encoded
func (ct *CallTracker) CapturedCalls() ([]byte, error) {
if !ct.IsCalled() {
return nil, nil
}
// else constructs an aggregated precompiled calls
apc := make(types.AggregatedPrecompiledCalls, 0)

sortedAddresses := make([]types.Address, 0, len(ct.callsByAddress))
// we need to sort by address to stay deterministic
for addr := range ct.callsByAddress {
sortedAddresses = append(sortedAddresses, addr)
}

sort.Slice(sortedAddresses,
func(i, j int) bool {
return bytes.Compare(sortedAddresses[i][:], sortedAddresses[j][:]) < 0
})

for _, addr := range sortedAddresses {
apc = append(apc, *ct.callsByAddress[addr])
}

return apc.Encode()
}

// Resets the tracker
func (ct *CallTracker) Reset() {
ct.callsByAddress = nil
}

type WrappedPrecompiledContract struct {
pc types.PrecompiledContract
ct *CallTracker
}

func (wpc *WrappedPrecompiledContract) Address() types.Address {
return wpc.pc.Address()
}
func (wpc *WrappedPrecompiledContract) RequiredGas(input []byte) uint64 {
output := wpc.pc.RequiredGas(input)
wpc.ct.CaptureRequiredGas(wpc.pc.Address(), input, output)
return output
}

func (wpc *WrappedPrecompiledContract) Run(input []byte) ([]byte, error) {
output, err := wpc.pc.Run(input)
wpc.ct.CaptureRun(wpc.pc.Address(), input, output, err)
return output, err
}
65 changes: 65 additions & 0 deletions fvm/evm/emulator/tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package emulator_test

import (
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/onflow/flow-go/fvm/evm/emulator"
"github.com/onflow/flow-go/fvm/evm/testutils"
"github.com/onflow/flow-go/fvm/evm/types"
)

func TestTracker(t *testing.T) {
apc := testutils.AggregatedPrecompiledCallsFixture(t)
var runCallCounter int
var requiredGasCallCounter int
pc := &MockedPrecompiled{
AddressFunc: func() types.Address {
return apc[0].Address
},
RequiredGasFunc: func(input []byte) uint64 {
res := apc[0].RequiredGasCalls[requiredGasCallCounter]
require.Equal(t, res.Input, input)
requiredGasCallCounter += 1
return res.Output
},
RunFunc: func(input []byte) ([]byte, error) {
res := apc[0].RunCalls[runCallCounter]
require.Equal(t, res.Input, input)
runCallCounter += 1
var err error
if len(res.ErrorMsg) > 0 {
err = errors.New(res.ErrorMsg)
}
return res.Output, err
},
}
tracker := emulator.NewCallTracker()
wpc := tracker.RegisterPrecompiledContract(pc)

require.Equal(t, apc[0].Address, wpc.Address())
for _, pc := range apc {
for _, call := range pc.RequiredGasCalls {
require.Equal(t, call.Output, wpc.RequiredGas(call.Input))
}
for _, call := range pc.RunCalls {
ret, err := wpc.Run(call.Input)
require.Equal(t, call.Output, ret)
errMsg := ""
if err != nil {
errMsg = err.Error()
}
require.Equal(t, call.ErrorMsg, errMsg)
}

}
require.True(t, tracker.IsCalled())

expectedEncoded, err := apc.Encode()
require.NoError(t, err)
encoded, err := tracker.CapturedCalls()
require.NoError(t, err)
require.Equal(t, expectedEncoded, encoded)
}
Loading

0 comments on commit f9bd3bf

Please sign in to comment.