Skip to content

Commit

Permalink
feat: add advancer package
Browse files Browse the repository at this point in the history
  • Loading branch information
renan061 committed Jul 23, 2024
1 parent f749688 commit 915eb8e
Show file tree
Hide file tree
Showing 10 changed files with 1,055 additions and 6 deletions.
108 changes: 108 additions & 0 deletions internal/node/advancer/advancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// (c) Cartesi and individual authors (see AUTHORS)
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)

package advancer

import (
"context"
"errors"
"fmt"
"log/slog"
"time"

"github.com/cartesi/rollups-node/internal/node/advancer/service"
"github.com/cartesi/rollups-node/internal/node/machine/nodemachine"
. "github.com/cartesi/rollups-node/internal/node/model"
)

var (
ErrInvalidMachines = errors.New("machines must not be nil")
ErrInvalidRepository = errors.New("repository must not be nil")
ErrInvalidAddress = errors.New("no machine for address")
)

type Advancer struct {
machines Machines
repository Repository
}

func New(machines Machines, repository Repository) (*Advancer, error) {
if machines == nil {
return nil, ErrInvalidMachines
}
if repository == nil {
return nil, ErrInvalidRepository
}
return &Advancer{machines: machines, repository: repository}, nil
}

func (advancer *Advancer) Poller(pollingInterval time.Duration) (*service.Poller, error) {
return service.NewPoller("advancer", advancer, pollingInterval)
}

func (advancer *Advancer) Run(ctx context.Context) error {
appAddresses := keysToSlice(advancer.machines)

// Gets the unprocessed inputs (of all apps) from the repository.
slog.Info("advancer: getting unprocessed inputs")
inputs, err := advancer.repository.GetInputs(ctx, appAddresses)
if err != nil {
return err
}

// Processes each set of inputs.
for appAddress, inputs := range inputs {
slog.Info(fmt.Sprintf("advancer: processing %d input(s) from %v", len(inputs), appAddress))

machine, ok := advancer.machines[appAddress]
if !ok {
return fmt.Errorf("%w %s", ErrInvalidAddress, appAddress.String())
}

// Processes inputs from the same application sequentially.
for _, input := range inputs {
slog.Info("advancer: processing input", "id", input.Id)

res, err := machine.Advance(ctx, input.RawData)
if err != nil {
return err
}

err = advancer.repository.StoreResults(ctx, input, res)
if err != nil {
return err
}
}
}

return nil
}

// ------------------------------------------------------------------------------------------------

type Repository interface {
// Only needs Id and RawData fields from model.Input.
GetInputs(context.Context, []Address) (map[Address][]*Input, error)

StoreResults(context.Context, *Input, *nodemachine.AdvanceResult) error
}

// A map of application addresses to machines.
type Machines = map[Address]Machine

type Machine interface {
Advance(context.Context, []byte) (*nodemachine.AdvanceResult, error)
}

// ------------------------------------------------------------------------------------------------

// keysToSlice returns a slice with the keys of a map.
func keysToSlice[K comparable, V any](m map[K]V) []K {
keys := make([]K, len(m))
i := 0
for k := range m {
keys[i] = k
i++
}
return keys
}
171 changes: 171 additions & 0 deletions internal/node/advancer/advancer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// (c) Cartesi and individual authors (see AUTHORS)
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)

package advancer

import (
"context"
crand "crypto/rand"
mrand "math/rand"
"testing"

"github.com/cartesi/rollups-node/internal/node/machine/nodemachine"
. "github.com/cartesi/rollups-node/internal/node/model"

"github.com/stretchr/testify/suite"
)

func TestAdvancer(t *testing.T) {
suite.Run(t, new(AdvancerSuite))
}

type AdvancerSuite struct{ suite.Suite }

func (s *AdvancerSuite) TestNew() {
s.Run("Ok", func() {
require := s.Require()
machines := Machines{randomAddress(): &MockMachine{}}
repository := &MockRepository{}
advancer, err := New(machines, repository)
require.NotNil(advancer)
require.Nil(err)
})

s.Run("InvalidMachines", func() {
require := s.Require()
repository := &MockRepository{}
advancer, err := New(nil, repository)
require.Nil(advancer)
require.Equal(ErrInvalidMachines, err)
})

s.Run("InvalidRepository", func() {
require := s.Require()
machines := Machines{randomAddress(): &MockMachine{}}
advancer, err := New(machines, nil)
require.Nil(advancer)
require.Equal(ErrInvalidRepository, err)
})
}

// NOTE: this test is just the beginning; we need more tests.
func (s *AdvancerSuite) TestRun() {
require := s.Require()

appAddress := randomAddress()

machines := Machines{}
advanceRes := randomAdvanceResult()
machines[appAddress] = &MockMachine{AdvanceVal: advanceRes, AdvanceErr: nil}

repository := &MockRepository{
GetInputsVal: map[Address][]*Input{appAddress: randomInputs(1)},
GetInputsErr: nil,
StoreResultsErr: nil,
}

advancer, err := New(machines, repository)
require.NotNil(advancer)
require.Nil(err)

err = advancer.Run(context.Background())
require.Nil(err)

require.Len(repository.Stored, 1)
require.Equal(advanceRes, repository.Stored[0])
}

// ------------------------------------------------------------------------------------------------

type MockMachine struct {
AdvanceVal *nodemachine.AdvanceResult
AdvanceErr error
}

func (mock *MockMachine) Advance(_ context.Context, _ []byte) (*nodemachine.AdvanceResult, error) {
return mock.AdvanceVal, mock.AdvanceErr
}

// ------------------------------------------------------------------------------------------------

type MockRepository struct {
GetInputsVal map[Address][]*Input
GetInputsErr error
StoreResultsErr error

Stored []*nodemachine.AdvanceResult
}

func (mock *MockRepository) GetInputs(
_ context.Context,
appAddresses []Address,
) (map[Address][]*Input, error) {
return mock.GetInputsVal, mock.GetInputsErr
}

func (mock *MockRepository) StoreResults(
_ context.Context,
input *Input,
res *nodemachine.AdvanceResult,
) error {
mock.Stored = append(mock.Stored, res)
return mock.StoreResultsErr
}

// ------------------------------------------------------------------------------------------------

func randomAddress() Address {
address := make([]byte, 20)
_, err := crand.Read(address)
if err != nil {
panic(err)
}
return Address(address)
}

func randomHash() Hash {
hash := make([]byte, 32)
_, err := crand.Read(hash)
if err != nil {
panic(err)
}
return Hash(hash)
}

func randomBytes() []byte {
size := mrand.Intn(100) + 1
bytes := make([]byte, size)
_, err := crand.Read(bytes)
if err != nil {
panic(err)
}
return bytes
}

func randomSliceOfBytes() [][]byte {
size := mrand.Intn(10) + 1
slice := make([][]byte, size)
for i := 0; i < size; i++ {
slice[i] = randomBytes()
}
return slice
}

func randomInputs(size int) []*Input {
slice := make([]*Input, size)
for i := 0; i < size; i++ {
slice[i] = &Input{Id: uint64(i), RawData: randomBytes()}
}
return slice

}

func randomAdvanceResult() *nodemachine.AdvanceResult {
return &nodemachine.AdvanceResult{
Status: InputStatusAccepted,
Outputs: randomSliceOfBytes(),
Reports: randomSliceOfBytes(),
OutputsHash: randomHash(),
MachineHash: randomHash(),
}
}
64 changes: 64 additions & 0 deletions internal/node/advancer/service/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// (c) Cartesi and individual authors (see AUTHORS)
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)

package service

import (
"context"
"errors"
"fmt"
"log/slog"
"sync/atomic"
"time"
)

type Service interface {
Run(context.Context) error
}

type Poller struct {
name string
service Service
shouldStop atomic.Bool
ticker *time.Ticker
}

var ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero")

func NewPoller(name string, service Service, pollingInterval time.Duration) (*Poller, error) {
if pollingInterval <= 0 {
return nil, ErrInvalidPollingInterval
}
ticker := time.NewTicker(pollingInterval)
return &Poller{name: name, service: service, ticker: ticker}, nil
}

func (poller *Poller) Start(ctx context.Context, ready chan<- struct{}) error {
ready <- struct{}{}

slog.Info(fmt.Sprintf("%s: started", poller.name))

for {
// Runs the service's inner routine.
err := poller.service.Run(ctx)
if err != nil {
return err
}

// Checks if the service was ordered to stop.
if poller.shouldStop.Load() {
poller.shouldStop.Store(false)
slog.Info(fmt.Sprintf("%s: stopped", poller.name))
return nil
}

// Waits for the polling interval to elapse.
slog.Info(fmt.Sprintf("%s: waiting for the polling interval to elapse", poller.name))
<-poller.ticker.C
}
}

// Stop orders the service to stop, which will happen before the next poll.
func (poller *Poller) Stop() {
poller.shouldStop.Store(true)
}
Loading

0 comments on commit 915eb8e

Please sign in to comment.