Skip to content

Commit

Permalink
Make previousGlobalState a value instead of pointer
Browse files Browse the repository at this point in the history
There's no longer a need to be able to pass in nil.
  • Loading branch information
eljobe committed Nov 26, 2024
1 parent bd7bd35 commit b6ee9d2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 38 deletions.
45 changes: 19 additions & 26 deletions staker/bold/bold_state_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewBOLDStateProvider(
func (s *BOLDStateProvider) ExecutionStateAfterPreviousState(
ctx context.Context,
maxInboxCount uint64,
previousGlobalState *protocol.GoGlobalState,
previousGlobalState protocol.GoGlobalState,
) (*protocol.ExecutionState, error) {
if maxInboxCount == 0 {
return nil, errors.New("max inbox count cannot be zero")
Expand All @@ -95,26 +95,24 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState(
}
return nil, err
}
if previousGlobalState != nil {
var previousMessageCount arbutil.MessageIndex
if previousGlobalState.Batch > 0 {
previousMessageCount, err = s.statelessValidator.InboxTracker().GetBatchMessageCount(previousGlobalState.Batch - 1)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, fmt.Errorf("%w: batch count %d", l2stateprovider.ErrChainCatchingUp, maxInboxCount)
}
return nil, err
var previousMessageCount arbutil.MessageIndex
if previousGlobalState.Batch > 0 {
previousMessageCount, err = s.statelessValidator.InboxTracker().GetBatchMessageCount(previousGlobalState.Batch - 1)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, fmt.Errorf("%w: batch count %d", l2stateprovider.ErrChainCatchingUp, maxInboxCount)
}
return nil, err
}
previousMessageCount += arbutil.MessageIndex(previousGlobalState.PosInBatch)
messageDiffBetweenBatches := messageCount - previousMessageCount
maxMessageCount := previousMessageCount + arbutil.MessageIndex(maxNumberOfBlocks)
if messageDiffBetweenBatches > maxMessageCount {
messageCount = maxMessageCount
batchIndex, _, err = s.statelessValidator.InboxTracker().FindInboxBatchContainingMessage(messageCount)
if err != nil {
return nil, err
}
}
previousMessageCount += arbutil.MessageIndex(previousGlobalState.PosInBatch)
messageDiffBetweenBatches := messageCount - previousMessageCount
maxMessageCount := previousMessageCount + arbutil.MessageIndex(maxNumberOfBlocks)
if messageDiffBetweenBatches > maxMessageCount {
messageCount = maxMessageCount
batchIndex, _, err = s.statelessValidator.InboxTracker().FindInboxBatchContainingMessage(messageCount)
if err != nil {
return nil, err
}
}
globalState, err := s.findGlobalStateFromMessageCountAndBatch(messageCount, l2stateprovider.Batch(batchIndex))
Expand All @@ -135,15 +133,10 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState(
GlobalState: protocol.GoGlobalState(globalState),
MachineStatus: protocol.MachineStatusFinished,
}

var previousGlobalStateOrDefault protocol.GoGlobalState
if previousGlobalState != nil {
previousGlobalStateOrDefault = *previousGlobalState
}
toBatch := executionState.GlobalState.Batch
historyCommitStates, _, err := s.StatesInBatchRange(
ctx,
previousGlobalStateOrDefault,
previousGlobalState,
toBatch,
l2stateprovider.Height(maxNumberOfBlocks),
)
Expand All @@ -155,7 +148,7 @@ func (s *BOLDStateProvider) ExecutionStateAfterPreviousState(
return nil, err
}
executionState.EndHistoryRoot = historyCommit.Merkle
fmt.Printf("ExecutionStateAfterPreviousState for previous state batch %v pos %v got end batch %v pos %v last leaf %v hash %v\n", previousGlobalStateOrDefault.Batch, previousGlobalStateOrDefault.PosInBatch, executionState.GlobalState.Batch, executionState.GlobalState.PosInBatch, historyCommitStates[len(historyCommitStates)-1], executionState.EndHistoryRoot)
fmt.Printf("ExecutionStateAfterPreviousState for previous state batch %v pos %v got end batch %v pos %v last leaf %v hash %v\n", previousGlobalState.Batch, previousGlobalState.PosInBatch, executionState.GlobalState.Batch, executionState.GlobalState.PosInBatch, historyCommitStates[len(historyCommitStates)-1], executionState.EndHistoryRoot)
return executionState, nil
}

Expand Down
4 changes: 2 additions & 2 deletions system_tests/bold_new_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ type incorrectBlockStateProvider struct {
func (s *incorrectBlockStateProvider) ExecutionStateAfterPreviousState(
ctx context.Context,
maxInboxCount uint64,
previousGlobalState *protocol.GoGlobalState,
previousGlobalState protocol.GoGlobalState,
) (*protocol.ExecutionState, error) {
maxNumberOfBlocks := s.chain.SpecChallengeManager().LayerZeroHeights().BlockChallengeHeight.Uint64()
executionState, err := s.honest.ExecutionStateAfterPreviousState(ctx, maxInboxCount, previousGlobalState)
if err != nil {
return nil, err
}
evilStates, err := s.L2MessageStatesUpTo(ctx, *previousGlobalState, l2stateprovider.Batch(maxInboxCount), option.Some(l2stateprovider.Height(maxNumberOfBlocks)))
evilStates, err := s.L2MessageStatesUpTo(ctx, previousGlobalState, l2stateprovider.Batch(maxInboxCount), option.Some(l2stateprovider.Height(maxNumberOfBlocks)))
if err != nil {
return nil, err
}
Expand Down
18 changes: 9 additions & 9 deletions system_tests/bold_state_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
_, err = stateManager.ExecutionStateAfterPreviousState(
ctx,
0,
&protocol.GoGlobalState{
protocol.GoGlobalState{
Batch: 0,
PosInBatch: 1,
},
Expand All @@ -254,7 +254,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
genesis, err := stateManager.ExecutionStateAfterPreviousState(
ctx,
1,
&protocol.GoGlobalState{
protocol.GoGlobalState{
Batch: 0,
PosInBatch: 0,
},
Expand All @@ -268,7 +268,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
first, err := stateManager.ExecutionStateAfterPreviousState(
ctx,
2,
&genesis.GlobalState,
genesis.GlobalState,
)
Require(t, err)
if first == nil {
Expand All @@ -279,7 +279,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
_, err = stateManager.ExecutionStateAfterPreviousState(
ctx,
10,
&first.GlobalState,
first.GlobalState,
)
if err == nil {
Fatal(t, "should not agree with execution state")
Expand All @@ -298,7 +298,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
SendRoot: result.SendRoot,
Batch: 3,
}
got, err := stateManager.ExecutionStateAfterPreviousState(ctx, 3, &first.GlobalState)
got, err := stateManager.ExecutionStateAfterPreviousState(ctx, 3, first.GlobalState)
Require(t, err)
if state.Batch != got.GlobalState.Batch {
Fatal(t, "wrong batch")
Expand All @@ -315,7 +315,7 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
_, err = stateManager.ExecutionStateAfterPreviousState(
ctx,
state.Batch+1,
&got.GlobalState,
got.GlobalState,
)
if err == nil {
Fatal(t, "should not agree with execution state")
Expand All @@ -325,17 +325,17 @@ func TestChallengeProtocolBOLD_StateProvider(t *testing.T) {
}
})
t.Run("ExecutionStateAfterBatchCount", func(t *testing.T) {
_, err = stateManager.ExecutionStateAfterPreviousState(ctx, 0, &protocol.GoGlobalState{})
_, err = stateManager.ExecutionStateAfterPreviousState(ctx, 0, protocol.GoGlobalState{})
if err == nil {
Fatal(t, "should have failed")
}
if !strings.Contains(err.Error(), "max inbox count cannot be zero") {
Fatal(t, "wrong error message", err)
}

genesis, err := stateManager.ExecutionStateAfterPreviousState(ctx, 1, &protocol.GoGlobalState{})
genesis, err := stateManager.ExecutionStateAfterPreviousState(ctx, 1, protocol.GoGlobalState{})
Require(t, err)
execState, err := stateManager.ExecutionStateAfterPreviousState(ctx, totalBatches, &genesis.GlobalState)
execState, err := stateManager.ExecutionStateAfterPreviousState(ctx, totalBatches, genesis.GlobalState)
Require(t, err)
if execState == nil {
Fatal(t, "should not be nil")
Expand Down

0 comments on commit b6ee9d2

Please sign in to comment.