Skip to content

Commit

Permalink
feat: state.{Get,Set}Extra[SA any](*StateDB,types.ExtraPayloads,...)
Browse files Browse the repository at this point in the history
  • Loading branch information
ARR4N committed Oct 4, 2024
1 parent 5ec080f commit 01bf6c2
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 0 deletions.
64 changes: 64 additions & 0 deletions core/state/state.libevm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

package state

import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)

// GetExtra returns the extra payload from the [types.StateAccount] associated
// with the address, or a zero-value `SA` if not found. The
// [types.ExtraPayloads] MUST be sourced from [types.RegisterExtras].
func GetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address) SA {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return p.FromStateAccount(&stateObject.data)
}
var zero SA
return zero
}

// SetExtra sets the extra payload for the address. See [GetExtra] for details.
func SetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
setExtraOnObject(stateObject, p, addr, extra)
}
}

func setExtraOnObject[SA any](s *stateObject, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
s.db.journal.append(extraChange[SA]{
payloads: p,
account: &addr,
prev: p.FromStateAccount(&s.data),
})
p.SetOnStateAccount(&s.data, extra)
}

// extraChange is a [journalEntry] for [SetExtra] / [setExtraOnObject].
type extraChange[SA any] struct {
payloads types.ExtraPayloads[SA]
account *common.Address
prev SA
}

func (e extraChange[SA]) dirtied() *common.Address { return e.account }

func (e extraChange[SA]) revert(s *StateDB) {
e.payloads.SetOnStateAccount(&s.getStateObject(*e.account).data, e.prev)
}
130 changes: 130 additions & 0 deletions core/state/state.libevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

package state_test

import (
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/libevm/ethtest"
"github.com/ethereum/go-ethereum/triedb"
)

func TestGetSetExtra(t *testing.T) {
types.TestOnlyClearRegisteredExtras()
t.Cleanup(types.TestOnlyClearRegisteredExtras)
payloads := types.RegisterExtras[[]byte]()

rng := ethtest.NewPseudoRand(42)
addr := rng.Address()
nonce := rng.Uint64()
balance := rng.Uint256()
extra := rng.Bytes(8)

views := newWithSnaps(t)
stateDB := views.stateDB
assert.Nilf(t, state.GetExtra(stateDB, payloads, addr), "state.GetExtra() returns zero-value %T if before SetExtra()", extra)
stateDB.CreateAccount(addr)
stateDB.SetNonce(addr, nonce)
stateDB.SetBalance(addr, balance)
state.SetExtra(stateDB, payloads, addr, extra)

root, err := stateDB.Commit(1, false) // arbitrary block number
require.NoErrorf(t, err, "%T.Commit(1, false)", stateDB)
require.NotEqualf(t, types.EmptyRootHash, root, "root hash returned by %T.Commit() is not the empty root", stateDB)

t.Run(fmt.Sprintf("retrieve from %T", views.snaps), func(t *testing.T) {
iter, err := views.snaps.AccountIterator(root, common.Hash{})
require.NoErrorf(t, err, "%T.AccountIterator(...)", views.snaps)
defer iter.Release()

require.Truef(t, iter.Next(), "%T.Next() (i.e. at least one account)", iter)
require.NoErrorf(t, iter.Error(), "%T.Error()", iter)

t.Run("types.FullAccount()", func(t *testing.T) {
got, err := types.FullAccount(iter.Account())
require.NoErrorf(t, err, "types.FullAccount(%T.Account())", iter)

want := &types.StateAccount{
Nonce: nonce,
Balance: balance,
Root: types.EmptyRootHash,
CodeHash: types.EmptyCodeHash[:],
}
payloads.SetOnStateAccount(want, extra)

if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("types.FullAccount(%T.Account()) diff (-want +got):\n%s", iter, diff)
}
})

require.Falsef(t, iter.Next(), "%T.Next() after first account (i.e. only one)", iter)
})

t.Run(fmt.Sprintf("retrieve from new %T", views.stateDB), func(t *testing.T) {
stateDB, err := state.New(root, views.database, views.snaps)
require.NoError(t, err, "state.New()")

// triggers SlimAccount RLP decoding
assert.Equalf(t, nonce, stateDB.GetNonce(addr), "%T.GetNonce()", stateDB)
assert.Equalf(t, balance, stateDB.GetBalance(addr), "%T.GetBalance()", stateDB)
assert.Equal(t, extra, state.GetExtra(stateDB, payloads, addr), "state.GetExtra()")
})
}

// stateViews are different ways to access the same data.
type stateViews struct {
stateDB *state.StateDB
snaps *snapshot.Tree
database state.Database
}

func newWithSnaps(t *testing.T) stateViews {
t.Helper()
empty := types.EmptyRootHash
kvStore := memorydb.New()
ethDB := rawdb.NewDatabase(kvStore)
snaps, err := snapshot.New(
snapshot.Config{
CacheSize: 16, // Mb (arbitrary but non-zero)
},
kvStore,
triedb.NewDatabase(ethDB, nil),
empty,
)
require.NoError(t, err, "snapshot.New()")

database := state.NewDatabase(ethDB)
stateDB, err := state.New(empty, database, snaps)
require.NoError(t, err, "state.New()")

return stateViews{
stateDB: stateDB,
snaps: snaps,
database: database,
}
}

0 comments on commit 01bf6c2

Please sign in to comment.