diff --git a/internal/node/machineadvancer/advancer.go b/internal/node/machineadvancer/advancer.go new file mode 100644 index 000000000..4c301b6d3 --- /dev/null +++ b/internal/node/machineadvancer/advancer.go @@ -0,0 +1,99 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package machineadvancer + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/node/nodemachine" +) + +type Machine interface { + Advance(_ context.Context, input []byte) (*nodemachine.AdvanceResponse, error) +} + +type MachineAdvancer struct { + machines map[model.Address]Machine + repository Repository + ticker *time.Ticker +} + +var ( + ErrInvalidMachines = errors.New("must have at least one machine") + ErrInvalidRepository = errors.New("repository must not be nil") + ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") + + ErrInvalidAddress = errors.New("invalid address from repository") +) + +// Duration must be greater than 0. +func New( + machines map[model.Address]Machine, + repository Repository, + pollingInterval time.Duration, +) (*MachineAdvancer, error) { + if len(machines) <= 0 { + return nil, ErrInvalidMachines + } + if repository == nil { + return nil, ErrInvalidRepository + } + if pollingInterval <= 0 { + return nil, ErrInvalidPollingInterval + } + return &MachineAdvancer{ + machines: machines, + repository: repository, + ticker: time.NewTicker(pollingInterval), + }, nil +} + +func (advancer *MachineAdvancer) Start() error { + addresses := keysToSlice(advancer.machines) + for { + // Gets the unprocessed inputs (of all apps) from the repository. + inputs, err := advancer.repository.GetInputs(addresses) + if err != nil { + return err + } + + for appAddress, inputs := range inputs { + machine, ok := advancer.machines[appAddress] + if !ok { + return fmt.Errorf("%w: %s", ErrInvalidAddress, appAddress.String()) + } + + // Processes all inputs sequentially. + for _, input := range inputs { + res, err := machine.Advance(context.Background(), input.RawData) + if err != nil { + return err + } + + err = advancer.repository.Store(input, res) + if err != nil { + return err + } + } + } + + // Waits for the current polling interval to elapse. + <-advancer.ticker.C + } +} + +// keysToSlice returns a slice with the keys of a map. +func keysToSlice[T comparable, U any](m map[T]U) []T { + keys := make([]T, len(m)) + i := 0 + for key := range m { + keys[i] = key + i++ + } + return keys +} diff --git a/internal/node/machineadvancer/advancer_test.go b/internal/node/machineadvancer/advancer_test.go new file mode 100644 index 000000000..970e4a433 --- /dev/null +++ b/internal/node/machineadvancer/advancer_test.go @@ -0,0 +1,242 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package machineadvancer + +import ( + "context" + crand "crypto/rand" + "errors" + mrand "math/rand" + "testing" + "time" + + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/node/nodemachine" + + "github.com/stretchr/testify/suite" +) + +func TestMachineAdvancer(t *testing.T) { + suite.Run(t, new(MachineAdvancerSuite)) +} + +type MachineAdvancerSuite struct{ suite.Suite } + +func (s *MachineAdvancerSuite) TestNew() { + s.Run("Ok", func() { + require := s.Require() + machines := map[model.Address]Machine{randomAddress(): newMockMachine()} + repository := newMockRepository() + machineAdvancer, err := New(machines, repository, time.Nanosecond) + require.NotNil(machineAdvancer) + require.Nil(err) + }) + + s.Run("InvalidMachines", func() { + require := s.Require() + repository := newMockRepository() + machineAdvancer, err := New(nil, repository, time.Nanosecond) + require.Nil(machineAdvancer) + require.Equal(ErrInvalidMachines, err) + }) + + s.Run("InvalidRepository", func() { + require := s.Require() + machines := map[model.Address]Machine{randomAddress(): newMockMachine()} + machineAdvancer, err := New(machines, nil, time.Nanosecond) + require.Nil(machineAdvancer) + require.Equal(ErrInvalidRepository, err) + }) + + s.Run("InvalidPollingInterval", func() { + require := s.Require() + machines := map[model.Address]Machine{randomAddress(): newMockMachine()} + repository := newMockRepository() + machineAdvancer, err := New(machines, repository, time.Duration(0)) + require.Nil(machineAdvancer) + require.Equal(ErrInvalidPollingInterval, err) + }) +} + +func (s *MachineAdvancerSuite) TestStart() { + suite.Run(s.T(), new(StartSuite)) +} + +// ------------------------------------------------------------------------------------------------ + +type StartSuite struct { + suite.Suite + machines map[model.Address]Machine + repository *MockRepository +} + +func (s *StartSuite) SetupTest() { + s.machines = map[model.Address]Machine{} + s.repository = newMockRepository() +} + +// NOTE: This test is very basic! We need more tests! +func (s *StartSuite) TestBasic() { + require := s.Require() + + appAddress := randomAddress() + + machine := newMockMachine() + advanceResponse := randomAdvanceResponse() + machine.add(advanceResponse, nil) + s.machines[appAddress] = machine + + s.repository.add(map[model.Address][]model.Input{appAddress: randomInputs(1)}, nil, nil) + + machineAdvancer, err := New(s.machines, s.repository, time.Nanosecond) + require.NotNil(machineAdvancer) + require.Nil(err) + + err = machineAdvancer.Start() + require.Equal(testFinished, err) + + require.Len(s.repository.stored, 1) + require.Equal(advanceResponse, s.repository.stored[0]) +} + +// ------------------------------------------------------------------------------------------------ + +type MockMachine struct { + index uint8 + results []*nodemachine.AdvanceResponse + errors []error +} + +func newMockMachine() *MockMachine { + return &MockMachine{ + index: 0, + results: []*nodemachine.AdvanceResponse{}, + errors: []error{}, + } +} + +func (m *MockMachine) add(result *nodemachine.AdvanceResponse, err error) { + m.results = append(m.results, result) + m.errors = append(m.errors, err) +} + +func (m *MockMachine) Advance( + _ context.Context, + input []byte, +) (*nodemachine.AdvanceResponse, error) { + result, err := m.results[m.index], m.errors[m.index] + m.index += 1 + return result, err +} + +// ------------------------------------------------------------------------------------------------ + +type MockRepository struct { + getInputsIndex uint8 + getInputsResults []map[model.Address][]model.Input + getInputsErrors []error + + storeIndex uint8 + storeErrors []error + stored []*nodemachine.AdvanceResponse +} + +func newMockRepository() *MockRepository { + return &MockRepository{ + getInputsIndex: 0, + getInputsResults: []map[model.Address][]model.Input{}, + getInputsErrors: []error{}, + storeIndex: 0, + storeErrors: []error{}, + stored: []*nodemachine.AdvanceResponse{}, + } +} + +func (r *MockRepository) add( + getInputsResult map[model.Address][]model.Input, + getInputsError error, + storeError error, +) { + r.getInputsResults = append(r.getInputsResults, getInputsResult) + r.getInputsErrors = append(r.getInputsErrors, getInputsError) + r.storeErrors = append(r.storeErrors, storeError) +} + +var testFinished = errors.New("test finished") + +func (r *MockRepository) GetInputs( + appAddresses []model.Address, +) (map[model.Address][]model.Input, error) { + if int(r.getInputsIndex) == len(r.getInputsResults) { + return nil, testFinished + } + result, err := r.getInputsResults[r.getInputsIndex], r.getInputsErrors[r.getInputsIndex] + r.getInputsIndex += 1 + return result, err +} + +func (r *MockRepository) Store(input model.Input, res *nodemachine.AdvanceResponse) error { + err := r.storeErrors[r.storeIndex] + r.storeIndex += 1 + r.stored = append(r.stored, res) + return err +} + +// ------------------------------------------------------------------------------------------------ + +func randomAddress() model.Address { + address := make([]byte, 20) + _, err := crand.Read(address) + if err != nil { + panic(err) + } + return model.Address(address) +} + +func randomHash() model.Hash { + hash := make([]byte, 32) + _, err := crand.Read(hash) + if err != nil { + panic(err) + } + return model.Hash(hash) +} + +func randomBytes() []byte { + size := mrand.Intn(100) + 1 + bytes := make([]byte, size) + _, err := crand.Read(bytes) + if err != nil { + panic(err) + } + return bytes +} + +func randomSliceOfBytes() [][]byte { + size := mrand.Intn(10) + 1 + slice := make([][]byte, size) + for i := 0; i < size; i++ { + slice[i] = randomBytes() + } + return slice +} + +func randomInputs(size int) []model.Input { + slice := make([]model.Input, size) + for i := 0; i < size; i++ { + slice[i] = model.Input{Id: uint64(i), RawData: randomBytes()} + } + return slice + +} + +func randomAdvanceResponse() *nodemachine.AdvanceResponse { + return &nodemachine.AdvanceResponse{ + Status: model.InputStatusAccepted, + Outputs: randomSliceOfBytes(), + Reports: randomSliceOfBytes(), + OutputsHash: randomHash(), + MachineHash: randomHash(), + } +} diff --git a/internal/node/machineadvancer/repository.go b/internal/node/machineadvancer/repository.go new file mode 100644 index 000000000..252f1061c --- /dev/null +++ b/internal/node/machineadvancer/repository.go @@ -0,0 +1,16 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package machineadvancer + +import ( + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/node/nodemachine" +) + +type Repository interface { + // Only needs Id and RawData fields from model.Input. + GetInputs(appAddresses []model.Address) (map[model.Address][]model.Input, error) + + Store(model.Input, *nodemachine.AdvanceResponse) error +} diff --git a/internal/node/nodemachine/machine.go b/internal/node/nodemachine/machine.go new file mode 100644 index 000000000..029b81901 --- /dev/null +++ b/internal/node/nodemachine/machine.go @@ -0,0 +1,212 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package nodemachine + +import ( + "context" + "errors" + "time" + + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/node/nodemachine/pmutex" + "github.com/cartesi/rollups-node/pkg/rollupsmachine" + + "golang.org/x/sync/semaphore" +) + +var ErrTimeLimitExceeded = errors.New("time limit exceeded") + +type AdvanceResponse struct { + Status model.InputCompletionStatus + Outputs [][]byte + Reports [][]byte + OutputsHash model.Hash + MachineHash model.Hash +} + +func (res AdvanceResponse) StatusOk() bool { + return res.Status == model.InputStatusAccepted || res.Status == model.InputStatusRejected +} + +type InspectResponse struct { + Accepted bool + Reports [][]byte + Err error +} + +type RollupsMachine interface { + Fork() (RollupsMachine, error) + Close() error + + Hash() (model.Hash, error) + + Advance([]byte) (bool, [][]byte, [][]byte, model.Hash, error) + Inspect([]byte) (bool, [][]byte, error) +} + +type NodeMachine struct { + RollupsMachine + + // Timeout in seconds. + timeout time.Duration + + // Ensures advance/inspect mutual exclusion when accessing the inner RollupsMachine. + // Advances have a higher priority than Inspects to acquire the lock. + mutex *pmutex.PMutex + + // Controls how many inspects can be concurrently active. + inspects *semaphore.Weighted +} + +func New(rollupsMachine RollupsMachine, maxConcurrentInspects int8) *NodeMachine { + return &NodeMachine{ + RollupsMachine: rollupsMachine, + mutex: pmutex.New(), + inspects: semaphore.NewWeighted(int64(maxConcurrentInspects)), + } +} + +func (machine *NodeMachine) Advance(ctx context.Context, input []byte) (*AdvanceResponse, error) { + var fork RollupsMachine + var err error + + { // Forks the machine. + machine.mutex.HLock() + defer machine.mutex.Unlock() + fork, err = machine.Fork() + if err != nil { + return nil, err + } + } + + // Sends the advance-state request to the forked machine. + res, err, timedOut := runWithTimeout(ctx, machine.timeout, func() (*AdvanceResponse, error) { + accepted, outputs, reports, outputsHash, err := fork.Advance(input) + status, err := toInputStatus(accepted, err) + if err != nil { + return nil, err + } + return &AdvanceResponse{ + Status: status, + Outputs: outputs, + Reports: reports, + OutputsHash: outputsHash, + }, nil + }) + if err != nil { + goto end + } + if timedOut { + res = &AdvanceResponse{Status: model.InputStatusTimeLimitExceeded} + } + + // Only gets the post-advance machine hash if the request was accepted. + if res.Status == model.InputStatusAccepted { + res.MachineHash, err = fork.Hash() + if err != nil { + goto end + } + } + + // If the forked machine is in a valid state: + if res.StatusOk() { + // Switches the current machine and the forked machine. + machine.mutex.HLock() + defer machine.mutex.Unlock() + fork, machine.RollupsMachine = machine.RollupsMachine, fork + } + +end: + return res, errors.Join(err, fork.Close()) +} + +func (machine *NodeMachine) Inspect(ctx context.Context, query []byte) (*InspectResponse, error) { + // Controls how many inspects can be concurrently active. + err := machine.inspects.Acquire(ctx, 1) + if err != nil { + return nil, err + } + defer machine.inspects.Release(1) + + var fork RollupsMachine + + { // Forks the machine. + machine.mutex.LLock() + defer machine.mutex.Unlock() + fork, err = machine.RollupsMachine.Fork() + if err != nil { + return nil, err + } + } + + // Sends the inspect-state request to the forked machine. + res, _, timedOut := runWithTimeout(ctx, machine.timeout, func() (*InspectResponse, error) { + accepted, reports, err := fork.Inspect(query) + return &InspectResponse{Accepted: accepted, Reports: reports, Err: err}, nil + }) + if timedOut { + res = &InspectResponse{Err: ErrTimeLimitExceeded} + } + + return res, fork.Close() +} + +// ------------------------------------------------------------------------------------------------ + +func toInputStatus(accepted bool, err error) (status model.InputCompletionStatus, _ error) { + switch err { + case nil: + if accepted { + return model.InputStatusAccepted, nil + } else { + return model.InputStatusRejected, nil + } + case rollupsmachine.ErrException: + return model.InputStatusException, nil + case rollupsmachine.ErrHalted: + return model.InputStatusMachineHalted, nil + case rollupsmachine.ErrCycleLimitExceeded: + return model.InputStatusCycleLimitExceeded, nil + case rollupsmachine.ErrOutputsLimitExceeded: + panic("TODO") + case rollupsmachine.ErrCartesiMachine, + rollupsmachine.ErrProgress, + rollupsmachine.ErrSoftYield: + return status, err + default: + return status, err + } + + // ErrPayloadLengthLimitExceeded + // InputStatusPayloadLengthLimitExceeded +} + +func runWithTimeout[T any]( + ctx context.Context, + timeout time.Duration, + f func() (*T, error), +) (_ *T, _ error, timedOut bool) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + success := make(chan *T, 1) + failure := make(chan error, 1) + go func() { + t, err := f() + if err != nil { + failure <- err + } else { + success <- t + } + }() + + select { + case <-ctx.Done(): + return nil, nil, true + case t := <-success: + return t, nil, false + case err := <-failure: + return nil, err, false + } +} diff --git a/internal/node/nodemachine/pmutex/pmutex.go b/internal/node/nodemachine/pmutex/pmutex.go new file mode 100644 index 000000000..585e7555a --- /dev/null +++ b/internal/node/nodemachine/pmutex/pmutex.go @@ -0,0 +1,56 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package pmutex + +import ( + "sync" + "sync/atomic" +) + +// A PMutex is a mutual exclusion lock with priority capabilities. +// A call to HLock always acquires the mutex before LLock. +type PMutex struct { + // Main mutex. + mutex *sync.Mutex + + // Condition variable for the waiting low-priority threads. + waitingLow *sync.Cond + + // Quantity of high-priority threads waiting to acquire the lock. + waitingHigh *atomic.Int32 +} + +// New creates a new PMutex. +func New() *PMutex { + mutex := &sync.Mutex{} + return &PMutex{ + mutex: mutex, + waitingLow: sync.NewCond(mutex), + waitingHigh: &atomic.Int32{}, + } +} + +// HLock acquires the mutex for high-priority threads. +func (lock *PMutex) HLock() { + lock.waitingHigh.Add(1) + lock.mutex.Lock() + lock.waitingHigh.Add(-1) +} + +// LLock acquires the mutex for low-priority threads. +// (It waits until there are no high-priority threads trying to acquire the lock.) +func (lock *PMutex) LLock() { + lock.mutex.Lock() + for lock.waitingHigh.Load() != 0 { + // NOTE: a cond.Wait() releases the lock uppon being called + // and tries to acquire it after being awakened. + lock.waitingLow.Wait() + } +} + +// Unlock releases the mutex for both types of threads. +func (lock *PMutex) Unlock() { + lock.waitingLow.Broadcast() + lock.mutex.Unlock() +}