-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,055 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package advancer | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"time" | ||
|
||
"github.com/cartesi/rollups-node/internal/node/advancer/service" | ||
. "github.com/cartesi/rollups-node/internal/node/model" | ||
"github.com/cartesi/rollups-node/internal/node/nodemachine" | ||
) | ||
|
||
var ( | ||
ErrInvalidMachines = errors.New("machines must not be nil") | ||
ErrInvalidRepository = errors.New("repository must not be nil") | ||
ErrInvalidAddress = errors.New("no machine for address") | ||
) | ||
|
||
type Advancer struct { | ||
machines Machines | ||
repository Repository | ||
} | ||
|
||
func New(machines Machines, repository Repository) (*Advancer, error) { | ||
if machines == nil { | ||
return nil, ErrInvalidMachines | ||
} | ||
if repository == nil { | ||
return nil, ErrInvalidRepository | ||
} | ||
return &Advancer{machines: machines, repository: repository}, nil | ||
} | ||
|
||
func (advancer *Advancer) Poller(pollingInterval time.Duration) (*service.Poller, error) { | ||
return service.NewPoller("advancer", advancer, pollingInterval) | ||
} | ||
|
||
func (advancer *Advancer) Run(ctx context.Context) error { | ||
appAddresses := keysToSlice(advancer.machines) | ||
|
||
// Gets the unprocessed inputs (of all apps) from the repository. | ||
slog.Info("advancer: getting unprocessed inputs") | ||
inputs, err := advancer.repository.GetInputs(ctx, appAddresses) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Processes each set of inputs. | ||
for appAddress, inputs := range inputs { | ||
slog.Info(fmt.Sprintf("advancer: processing %d input(s) from %v", len(inputs), appAddress)) | ||
|
||
machine, ok := advancer.machines[appAddress] | ||
if !ok { | ||
return fmt.Errorf("%w %s", ErrInvalidAddress, appAddress.String()) | ||
} | ||
|
||
// Processes inputs from the same application sequentially. | ||
for _, input := range inputs { | ||
slog.Info("advancer: processing input", "id", input.Id) | ||
|
||
res, err := machine.Advance(ctx, input.RawData) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
err = advancer.repository.StoreResults(ctx, input, res) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type Repository interface { | ||
// Only needs Id and RawData fields from model.Input. | ||
GetInputs(context.Context, []Address) (map[Address][]*Input, error) | ||
|
||
StoreResults(context.Context, *Input, *nodemachine.AdvanceResult) error | ||
} | ||
|
||
// A map of application addresses to machines. | ||
type Machines = map[Address]Machine | ||
|
||
type Machine interface { | ||
Advance(context.Context, []byte) (*nodemachine.AdvanceResult, error) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
// keysToSlice returns a slice with the keys of a map. | ||
func keysToSlice[K comparable, V any](m map[K]V) []K { | ||
keys := make([]K, len(m)) | ||
i := 0 | ||
for k := range m { | ||
keys[i] = k | ||
i++ | ||
} | ||
return keys | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package advancer | ||
|
||
import ( | ||
"context" | ||
crand "crypto/rand" | ||
mrand "math/rand" | ||
"testing" | ||
|
||
. "github.com/cartesi/rollups-node/internal/node/model" | ||
"github.com/cartesi/rollups-node/internal/node/nodemachine" | ||
|
||
"github.com/stretchr/testify/suite" | ||
) | ||
|
||
func TestAdvancer(t *testing.T) { | ||
suite.Run(t, new(AdvancerSuite)) | ||
} | ||
|
||
type AdvancerSuite struct{ suite.Suite } | ||
|
||
func (s *AdvancerSuite) TestNew() { | ||
s.Run("Ok", func() { | ||
require := s.Require() | ||
machines := Machines{randomAddress(): &MockMachine{}} | ||
repository := &MockRepository{} | ||
advancer, err := New(machines, repository) | ||
require.NotNil(advancer) | ||
require.Nil(err) | ||
}) | ||
|
||
s.Run("InvalidMachines", func() { | ||
require := s.Require() | ||
repository := &MockRepository{} | ||
advancer, err := New(nil, repository) | ||
require.Nil(advancer) | ||
require.Equal(ErrInvalidMachines, err) | ||
}) | ||
|
||
s.Run("InvalidRepository", func() { | ||
require := s.Require() | ||
machines := Machines{randomAddress(): &MockMachine{}} | ||
advancer, err := New(machines, nil) | ||
require.Nil(advancer) | ||
require.Equal(ErrInvalidRepository, err) | ||
}) | ||
} | ||
|
||
// NOTE: this test is just the beginning; we need more tests. | ||
func (s *AdvancerSuite) TestRun() { | ||
require := s.Require() | ||
|
||
appAddress := randomAddress() | ||
|
||
machines := Machines{} | ||
advanceRes := randomAdvanceResult() | ||
machines[appAddress] = &MockMachine{AdvanceVal: advanceRes, AdvanceErr: nil} | ||
|
||
repository := &MockRepository{ | ||
GetInputsVal: map[Address][]*Input{appAddress: randomInputs(1)}, | ||
GetInputsErr: nil, | ||
StoreResultsErr: nil, | ||
} | ||
|
||
advancer, err := New(machines, repository) | ||
require.NotNil(advancer) | ||
require.Nil(err) | ||
|
||
err = advancer.Run(context.Background()) | ||
require.Nil(err) | ||
|
||
require.Len(repository.Stored, 1) | ||
require.Equal(advanceRes, repository.Stored[0]) | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockMachine struct { | ||
AdvanceVal *nodemachine.AdvanceResult | ||
AdvanceErr error | ||
} | ||
|
||
func (mock *MockMachine) Advance(_ context.Context, _ []byte) (*nodemachine.AdvanceResult, error) { | ||
return mock.AdvanceVal, mock.AdvanceErr | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
type MockRepository struct { | ||
GetInputsVal map[Address][]*Input | ||
GetInputsErr error | ||
StoreResultsErr error | ||
|
||
Stored []*nodemachine.AdvanceResult | ||
} | ||
|
||
func (mock *MockRepository) GetInputs( | ||
_ context.Context, | ||
appAddresses []Address, | ||
) (map[Address][]*Input, error) { | ||
return mock.GetInputsVal, mock.GetInputsErr | ||
} | ||
|
||
func (mock *MockRepository) StoreResults( | ||
_ context.Context, | ||
input *Input, | ||
res *nodemachine.AdvanceResult, | ||
) error { | ||
mock.Stored = append(mock.Stored, res) | ||
return mock.StoreResultsErr | ||
} | ||
|
||
// ------------------------------------------------------------------------------------------------ | ||
|
||
func randomAddress() Address { | ||
address := make([]byte, 20) | ||
_, err := crand.Read(address) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return Address(address) | ||
} | ||
|
||
func randomHash() Hash { | ||
hash := make([]byte, 32) | ||
_, err := crand.Read(hash) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return Hash(hash) | ||
} | ||
|
||
func randomBytes() []byte { | ||
size := mrand.Intn(100) + 1 | ||
bytes := make([]byte, size) | ||
_, err := crand.Read(bytes) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return bytes | ||
} | ||
|
||
func randomSliceOfBytes() [][]byte { | ||
size := mrand.Intn(10) + 1 | ||
slice := make([][]byte, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = randomBytes() | ||
} | ||
return slice | ||
} | ||
|
||
func randomInputs(size int) []*Input { | ||
slice := make([]*Input, size) | ||
for i := 0; i < size; i++ { | ||
slice[i] = &Input{Id: uint64(i), RawData: randomBytes()} | ||
} | ||
return slice | ||
|
||
} | ||
|
||
func randomAdvanceResult() *nodemachine.AdvanceResult { | ||
return &nodemachine.AdvanceResult{ | ||
Status: InputStatusAccepted, | ||
Outputs: randomSliceOfBytes(), | ||
Reports: randomSliceOfBytes(), | ||
OutputsHash: randomHash(), | ||
MachineHash: randomHash(), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// (c) Cartesi and individual authors (see AUTHORS) | ||
// SPDX-License-Identifier: Apache-2.0 (see LICENSE) | ||
|
||
package service | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
type Service interface { | ||
Run(context.Context) error | ||
} | ||
|
||
type Poller struct { | ||
name string | ||
service Service | ||
shouldStop atomic.Bool | ||
ticker *time.Ticker | ||
} | ||
|
||
var ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") | ||
|
||
func NewPoller(name string, service Service, pollingInterval time.Duration) (*Poller, error) { | ||
if pollingInterval <= 0 { | ||
return nil, ErrInvalidPollingInterval | ||
} | ||
ticker := time.NewTicker(pollingInterval) | ||
return &Poller{name: name, service: service, ticker: ticker}, nil | ||
} | ||
|
||
func (poller *Poller) Start(ctx context.Context, ready chan<- struct{}) error { | ||
ready <- struct{}{} | ||
|
||
slog.Info(fmt.Sprintf("%s: started", poller.name)) | ||
|
||
for { | ||
// Runs the service's inner routine. | ||
err := poller.service.Run(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Checks if the service was ordered to stop. | ||
if poller.shouldStop.Load() { | ||
poller.shouldStop.Store(false) | ||
slog.Info(fmt.Sprintf("%s: stopped", poller.name)) | ||
return nil | ||
} | ||
|
||
// Waits for the polling interval to elapse. | ||
slog.Info(fmt.Sprintf("%s: waiting for the polling interval to elapse", poller.name)) | ||
<-poller.ticker.C | ||
} | ||
} | ||
|
||
// Stop orders the service to stop, which will happen before the next poll. | ||
func (poller *Poller) Stop() { | ||
poller.shouldStop.Store(true) | ||
} |
Oops, something went wrong.