From 915eb8e794a6b60ad684c7dcc5189dc700247349 Mon Sep 17 00:00:00 2001 From: Renan Santos Date: Fri, 19 Jul 2024 13:45:39 -0300 Subject: [PATCH] feat: add advancer package --- internal/node/advancer/advancer.go | 108 +++++++++ internal/node/advancer/advancer_test.go | 171 +++++++++++++ internal/node/advancer/service/service.go | 64 +++++ internal/node/machine/nodemachine/machine.go | 213 ++++++++++++++++ .../node/machine/nodemachine/pmutex/pmutex.go | 56 +++++ internal/repository/advancer.go | 196 +++++++++++++++ internal/repository/base.go | 17 +- ...nput_claim_output_report_nodeconfig.up.sql | 4 +- internal/repository/schemamanager.go | 4 + test/advancer/advancer_test.go | 228 ++++++++++++++++++ 10 files changed, 1055 insertions(+), 6 deletions(-) create mode 100644 internal/node/advancer/advancer.go create mode 100644 internal/node/advancer/advancer_test.go create mode 100644 internal/node/advancer/service/service.go create mode 100644 internal/node/machine/nodemachine/machine.go create mode 100644 internal/node/machine/nodemachine/pmutex/pmutex.go create mode 100644 internal/repository/advancer.go create mode 100644 test/advancer/advancer_test.go diff --git a/internal/node/advancer/advancer.go b/internal/node/advancer/advancer.go new file mode 100644 index 000000000..a640c9af6 --- /dev/null +++ b/internal/node/advancer/advancer.go @@ -0,0 +1,108 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package advancer + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/cartesi/rollups-node/internal/node/advancer/service" + "github.com/cartesi/rollups-node/internal/node/machine/nodemachine" + . "github.com/cartesi/rollups-node/internal/node/model" +) + +var ( + ErrInvalidMachines = errors.New("machines must not be nil") + ErrInvalidRepository = errors.New("repository must not be nil") + ErrInvalidAddress = errors.New("no machine for address") +) + +type Advancer struct { + machines Machines + repository Repository +} + +func New(machines Machines, repository Repository) (*Advancer, error) { + if machines == nil { + return nil, ErrInvalidMachines + } + if repository == nil { + return nil, ErrInvalidRepository + } + return &Advancer{machines: machines, repository: repository}, nil +} + +func (advancer *Advancer) Poller(pollingInterval time.Duration) (*service.Poller, error) { + return service.NewPoller("advancer", advancer, pollingInterval) +} + +func (advancer *Advancer) Run(ctx context.Context) error { + appAddresses := keysToSlice(advancer.machines) + + // Gets the unprocessed inputs (of all apps) from the repository. + slog.Info("advancer: getting unprocessed inputs") + inputs, err := advancer.repository.GetInputs(ctx, appAddresses) + if err != nil { + return err + } + + // Processes each set of inputs. + for appAddress, inputs := range inputs { + slog.Info(fmt.Sprintf("advancer: processing %d input(s) from %v", len(inputs), appAddress)) + + machine, ok := advancer.machines[appAddress] + if !ok { + return fmt.Errorf("%w %s", ErrInvalidAddress, appAddress.String()) + } + + // Processes inputs from the same application sequentially. + for _, input := range inputs { + slog.Info("advancer: processing input", "id", input.Id) + + res, err := machine.Advance(ctx, input.RawData) + if err != nil { + return err + } + + err = advancer.repository.StoreResults(ctx, input, res) + if err != nil { + return err + } + } + } + + return nil +} + +// ------------------------------------------------------------------------------------------------ + +type Repository interface { + // Only needs Id and RawData fields from model.Input. + GetInputs(context.Context, []Address) (map[Address][]*Input, error) + + StoreResults(context.Context, *Input, *nodemachine.AdvanceResult) error +} + +// A map of application addresses to machines. +type Machines = map[Address]Machine + +type Machine interface { + Advance(context.Context, []byte) (*nodemachine.AdvanceResult, error) +} + +// ------------------------------------------------------------------------------------------------ + +// keysToSlice returns a slice with the keys of a map. +func keysToSlice[K comparable, V any](m map[K]V) []K { + keys := make([]K, len(m)) + i := 0 + for k := range m { + keys[i] = k + i++ + } + return keys +} diff --git a/internal/node/advancer/advancer_test.go b/internal/node/advancer/advancer_test.go new file mode 100644 index 000000000..466a6530a --- /dev/null +++ b/internal/node/advancer/advancer_test.go @@ -0,0 +1,171 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package advancer + +import ( + "context" + crand "crypto/rand" + mrand "math/rand" + "testing" + + "github.com/cartesi/rollups-node/internal/node/machine/nodemachine" + . "github.com/cartesi/rollups-node/internal/node/model" + + "github.com/stretchr/testify/suite" +) + +func TestAdvancer(t *testing.T) { + suite.Run(t, new(AdvancerSuite)) +} + +type AdvancerSuite struct{ suite.Suite } + +func (s *AdvancerSuite) TestNew() { + s.Run("Ok", func() { + require := s.Require() + machines := Machines{randomAddress(): &MockMachine{}} + repository := &MockRepository{} + advancer, err := New(machines, repository) + require.NotNil(advancer) + require.Nil(err) + }) + + s.Run("InvalidMachines", func() { + require := s.Require() + repository := &MockRepository{} + advancer, err := New(nil, repository) + require.Nil(advancer) + require.Equal(ErrInvalidMachines, err) + }) + + s.Run("InvalidRepository", func() { + require := s.Require() + machines := Machines{randomAddress(): &MockMachine{}} + advancer, err := New(machines, nil) + require.Nil(advancer) + require.Equal(ErrInvalidRepository, err) + }) +} + +// NOTE: this test is just the beginning; we need more tests. +func (s *AdvancerSuite) TestRun() { + require := s.Require() + + appAddress := randomAddress() + + machines := Machines{} + advanceRes := randomAdvanceResult() + machines[appAddress] = &MockMachine{AdvanceVal: advanceRes, AdvanceErr: nil} + + repository := &MockRepository{ + GetInputsVal: map[Address][]*Input{appAddress: randomInputs(1)}, + GetInputsErr: nil, + StoreResultsErr: nil, + } + + advancer, err := New(machines, repository) + require.NotNil(advancer) + require.Nil(err) + + err = advancer.Run(context.Background()) + require.Nil(err) + + require.Len(repository.Stored, 1) + require.Equal(advanceRes, repository.Stored[0]) +} + +// ------------------------------------------------------------------------------------------------ + +type MockMachine struct { + AdvanceVal *nodemachine.AdvanceResult + AdvanceErr error +} + +func (mock *MockMachine) Advance(_ context.Context, _ []byte) (*nodemachine.AdvanceResult, error) { + return mock.AdvanceVal, mock.AdvanceErr +} + +// ------------------------------------------------------------------------------------------------ + +type MockRepository struct { + GetInputsVal map[Address][]*Input + GetInputsErr error + StoreResultsErr error + + Stored []*nodemachine.AdvanceResult +} + +func (mock *MockRepository) GetInputs( + _ context.Context, + appAddresses []Address, +) (map[Address][]*Input, error) { + return mock.GetInputsVal, mock.GetInputsErr +} + +func (mock *MockRepository) StoreResults( + _ context.Context, + input *Input, + res *nodemachine.AdvanceResult, +) error { + mock.Stored = append(mock.Stored, res) + return mock.StoreResultsErr +} + +// ------------------------------------------------------------------------------------------------ + +func randomAddress() Address { + address := make([]byte, 20) + _, err := crand.Read(address) + if err != nil { + panic(err) + } + return Address(address) +} + +func randomHash() Hash { + hash := make([]byte, 32) + _, err := crand.Read(hash) + if err != nil { + panic(err) + } + return Hash(hash) +} + +func randomBytes() []byte { + size := mrand.Intn(100) + 1 + bytes := make([]byte, size) + _, err := crand.Read(bytes) + if err != nil { + panic(err) + } + return bytes +} + +func randomSliceOfBytes() [][]byte { + size := mrand.Intn(10) + 1 + slice := make([][]byte, size) + for i := 0; i < size; i++ { + slice[i] = randomBytes() + } + return slice +} + +func randomInputs(size int) []*Input { + slice := make([]*Input, size) + for i := 0; i < size; i++ { + slice[i] = &Input{Id: uint64(i), RawData: randomBytes()} + } + return slice + +} + +func randomAdvanceResult() *nodemachine.AdvanceResult { + return &nodemachine.AdvanceResult{ + Status: InputStatusAccepted, + Outputs: randomSliceOfBytes(), + Reports: randomSliceOfBytes(), + OutputsHash: randomHash(), + MachineHash: randomHash(), + } +} diff --git a/internal/node/advancer/service/service.go b/internal/node/advancer/service/service.go new file mode 100644 index 000000000..495dd25f6 --- /dev/null +++ b/internal/node/advancer/service/service.go @@ -0,0 +1,64 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync/atomic" + "time" +) + +type Service interface { + Run(context.Context) error +} + +type Poller struct { + name string + service Service + shouldStop atomic.Bool + ticker *time.Ticker +} + +var ErrInvalidPollingInterval = errors.New("polling interval must be greater than zero") + +func NewPoller(name string, service Service, pollingInterval time.Duration) (*Poller, error) { + if pollingInterval <= 0 { + return nil, ErrInvalidPollingInterval + } + ticker := time.NewTicker(pollingInterval) + return &Poller{name: name, service: service, ticker: ticker}, nil +} + +func (poller *Poller) Start(ctx context.Context, ready chan<- struct{}) error { + ready <- struct{}{} + + slog.Info(fmt.Sprintf("%s: started", poller.name)) + + for { + // Runs the service's inner routine. + err := poller.service.Run(ctx) + if err != nil { + return err + } + + // Checks if the service was ordered to stop. + if poller.shouldStop.Load() { + poller.shouldStop.Store(false) + slog.Info(fmt.Sprintf("%s: stopped", poller.name)) + return nil + } + + // Waits for the polling interval to elapse. + slog.Info(fmt.Sprintf("%s: waiting for the polling interval to elapse", poller.name)) + <-poller.ticker.C + } +} + +// Stop orders the service to stop, which will happen before the next poll. +func (poller *Poller) Stop() { + poller.shouldStop.Store(true) +} diff --git a/internal/node/machine/nodemachine/machine.go b/internal/node/machine/nodemachine/machine.go new file mode 100644 index 000000000..9b328e7e4 --- /dev/null +++ b/internal/node/machine/nodemachine/machine.go @@ -0,0 +1,213 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package nodemachine + +import ( + "context" + "errors" + "time" + + "github.com/cartesi/rollups-node/internal/node/machine/nodemachine/pmutex" + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/pkg/rollupsmachine" + + "golang.org/x/sync/semaphore" +) + +var ErrTimeLimitExceeded = errors.New("time limit exceeded") + +type AdvanceResult struct { + Status model.InputCompletionStatus + Outputs [][]byte + Reports [][]byte + OutputsHash model.Hash + MachineHash model.Hash +} + +func (res AdvanceResult) StatusOk() bool { + return res.Status == model.InputStatusAccepted || res.Status == model.InputStatusRejected +} + +type InspectResult struct { + Accepted bool + Reports [][]byte + Err error +} + +type RollupsMachine interface { + Fork() (*rollupsmachine.RollupsMachine, string, error) // NOTE: returns the concrete type + Close() error + + Hash() (model.Hash, error) + + Advance([]byte) (bool, [][]byte, [][]byte, model.Hash, error) + Inspect([]byte) (bool, [][]byte, error) +} + +type NodeMachine struct { + RollupsMachine + + // Timeout in seconds. + timeout time.Duration + + // Ensures advance/inspect mutual exclusion when accessing the inner RollupsMachine. + // Advances have a higher priority than Inspects to acquire the lock. + mutex *pmutex.PMutex + + // Controls how many inspects can be concurrently active. + inspects *semaphore.Weighted +} + +func New( + rollupsMachine RollupsMachine, + timeout time.Duration, + maxConcurrentInspects int8, +) *NodeMachine { + return &NodeMachine{ + RollupsMachine: rollupsMachine, + timeout: timeout, + mutex: pmutex.New(), + inspects: semaphore.NewWeighted(int64(maxConcurrentInspects)), + } +} + +func (machine *NodeMachine) Advance(ctx context.Context, input []byte) (*AdvanceResult, error) { + var fork RollupsMachine + var err error + + // Forks the machine. + machine.mutex.HLock() + fork, _, err = machine.Fork() + machine.mutex.Unlock() + if err != nil { + return nil, err + } + + // Sends the advance-state request to the forked machine. + accepted, outputs, reports, outputsHash, err := fork.Advance(input) + status, err := toInputStatus(accepted, err) + if err != nil { + return nil, errors.Join(err, fork.Close()) + } + + res := &AdvanceResult{ + Status: status, + Outputs: outputs, + Reports: reports, + OutputsHash: outputsHash, + } + + // Only gets the post-advance machine hash if the request was accepted. + if status == model.InputStatusAccepted { + res.MachineHash, err = fork.Hash() + if err != nil { + return nil, errors.Join(err, fork.Close()) + } + } + + // If the forked machine is in a valid state: + if res.StatusOk() { + // Closes the current machine. + err = machine.RollupsMachine.Close() + // Replaces the current machine with the fork. + machine.mutex.HLock() + machine.RollupsMachine = fork + machine.mutex.Unlock() + } else { + // Closes the forked machine. + err = fork.Close() + } + + return res, err +} + +func (machine *NodeMachine) Inspect(ctx context.Context, query []byte) (*InspectResult, error) { + // Controls how many inspects can be concurrently active. + err := machine.inspects.Acquire(ctx, 1) + if err != nil { + return nil, err + } + defer machine.inspects.Release(1) + + var fork RollupsMachine + + // Forks the machine. + machine.mutex.LLock() + fork, _, err = machine.RollupsMachine.Fork() + machine.mutex.Unlock() + if err != nil { + return nil, err + } + + // Sends the inspect-state request to the forked machine. + res, _, timedOut := runWithTimeout(ctx, machine.timeout, func() (*InspectResult, error) { + accepted, reports, err := fork.Inspect(query) + return &InspectResult{Accepted: accepted, Reports: reports, Err: err}, nil + }) + if timedOut { + res = &InspectResult{Err: ErrTimeLimitExceeded} + } + + return res, fork.Close() +} + +// ------------------------------------------------------------------------------------------------ + +func toInputStatus(accepted bool, err error) (status model.InputCompletionStatus, _ error) { + switch err { + case nil: + if accepted { + return model.InputStatusAccepted, nil + } else { + return model.InputStatusRejected, nil + } + case rollupsmachine.ErrException: + return model.InputStatusException, nil + case rollupsmachine.ErrHalted: + return model.InputStatusMachineHalted, nil + case rollupsmachine.ErrCycleLimitExceeded: + return model.InputStatusCycleLimitExceeded, nil + case rollupsmachine.ErrOutputsLimitExceeded: + panic("TODO") + case rollupsmachine.ErrCartesiMachine, + rollupsmachine.ErrProgress, + rollupsmachine.ErrSoftYield: + return status, err + default: + return status, err + } + + // ErrPayloadLengthLimitExceeded + // InputStatusPayloadLengthLimitExceeded +} + +// Unused. +func runWithTimeout[T any]( + ctx context.Context, + timeout time.Duration, + f func() (*T, error), +) (_ *T, _ error, timedOut bool) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + success := make(chan *T, 1) + failure := make(chan error, 1) + go func() { + t, err := f() + if err != nil { + failure <- err + } else { + success <- t + } + }() + + select { + case <-ctx.Done(): + return nil, nil, true + case t := <-success: + return t, nil, false + case err := <-failure: + return nil, err, false + } +} diff --git a/internal/node/machine/nodemachine/pmutex/pmutex.go b/internal/node/machine/nodemachine/pmutex/pmutex.go new file mode 100644 index 000000000..b494c9c0a --- /dev/null +++ b/internal/node/machine/nodemachine/pmutex/pmutex.go @@ -0,0 +1,56 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package pmutex + +import ( + "sync" + "sync/atomic" +) + +// A PMutex is a mutual exclusion lock with priority capabilities. +// A call to HLock always acquires the mutex before LLock. +type PMutex struct { + // Main mutex. + mutex *sync.Mutex + + // Condition variable for the waiting low-priority threads. + waitingLow *sync.Cond + + // Quantity of high-priority threads waiting to acquire the lock. + waitingHigh *atomic.Int32 +} + +// New creates a new PMutex. +func New() *PMutex { + mutex := &sync.Mutex{} + return &PMutex{ + mutex: mutex, + waitingLow: sync.NewCond(mutex), + waitingHigh: &atomic.Int32{}, + } +} + +// HLock acquires the mutex for high-priority threads. +func (pmutex *PMutex) HLock() { + pmutex.waitingHigh.Add(1) + pmutex.mutex.Lock() + pmutex.waitingHigh.Add(-1) +} + +// LLock acquires the mutex for low-priority threads. +// (It waits until there are no high-priority threads trying to acquire the lock.) +func (pmutex *PMutex) LLock() { + pmutex.mutex.Lock() + for pmutex.waitingHigh.Load() != 0 { + // NOTE: a cond.Wait() releases the lock uppon being called + // and tries to acquire it after being awakened. + pmutex.waitingLow.Wait() + } +} + +// Unlock releases the mutex for both types of threads. +func (pmutex *PMutex) Unlock() { + pmutex.waitingLow.Broadcast() + pmutex.mutex.Unlock() +} diff --git a/internal/repository/advancer.go b/internal/repository/advancer.go new file mode 100644 index 000000000..f072a2a1b --- /dev/null +++ b/internal/repository/advancer.go @@ -0,0 +1,196 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repository + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/cartesi/rollups-node/internal/node/machine/nodemachine" + . "github.com/cartesi/rollups-node/internal/node/model" + "github.com/jackc/pgx/v5" +) + +var ErrAdvancerRepository = errors.New("advancer repository error") + +type AdvancerRepository struct{ *Database } + +func (repository *AdvancerRepository) GetInputs( + ctx context.Context, + appAddresses []Address, +) (map[Address][]*Input, error) { + result := map[Address][]*Input{} + if len(appAddresses) == 0 { + return result, nil + } + + query := fmt.Sprintf(` + SELECT id, application_address, raw_data + FROM input + WHERE status = 'NONE' + AND application_address IN %s + ORDER BY index ASC, application_address + `, toIN(appAddresses)) // TODO: not sanitized + rows, err := repository.db.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("%w (failed querying inputs): %w", ErrAdvancerRepository, err) + } + + var input Input + scans := []any{&input.Id, &input.AppAddress, &input.RawData} + _, err = pgx.ForEachRow(rows, scans, func() error { + input := input + if _, ok := result[input.AppAddress]; ok { + result[input.AppAddress] = append(result[input.AppAddress], &input) + } else { + result[input.AppAddress] = []*Input{&input} + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("%w (failed reading input rows): %w", ErrAdvancerRepository, err) + } + + return result, nil +} + +func (repository *AdvancerRepository) StoreResults( + ctx context.Context, + input *Input, + res *nodemachine.AdvanceResult, +) error { + tx, err := repository.db.Begin(ctx) + if err != nil { + return errors.Join(ErrBeginTx, err) + } + + // Inserts the outputs. + nextOutputIndex, err := repository.getNextIndex(ctx, tx, "output", input.AppAddress) + if err != nil { + return err + } + err = repository.insert(ctx, tx, "output", res.Outputs, input.Id, nextOutputIndex) + if err != nil { + return err + } + + // Inserts the reports. + nextReportIndex, err := repository.getNextIndex(ctx, tx, "report", input.AppAddress) + if err != nil { + return err + } + err = repository.insert(ctx, tx, "report", res.Reports, input.Id, nextReportIndex) + if err != nil { + return err + } + + // Updates the input's status. + err = repository.updateInput(ctx, tx, input.Id, res.Status, res.OutputsHash, res.MachineHash) + if err != nil { + return err + } + + err = tx.Commit(ctx) + if err != nil { + return errors.Join(ErrCommitTx, err, tx.Rollback(ctx)) + } + + return nil +} + +// ------------------------------------------------------------------------------------------------ + +func (_ *AdvancerRepository) getNextIndex( + ctx context.Context, + tx pgx.Tx, + tableName string, + appAddress Address, +) (uint64, error) { + var nextIndex uint64 + query := fmt.Sprintf(` + SELECT COALESCE(MAX(%s.index) + 1, 0) + FROM input INNER JOIN %s ON input.id = %s.input_id + WHERE input.status = 'ACCEPTED' + AND input.application_address = $1 + `, tableName, tableName, tableName) + err := tx.QueryRow(ctx, query, appAddress).Scan(&nextIndex) + if err != nil { + err = fmt.Errorf("failed to get the next %s index: %w", tableName, err) + return 0, errors.Join(err, tx.Rollback(ctx)) + } + return nextIndex, nil +} + +func (_ *AdvancerRepository) insert( + ctx context.Context, + tx pgx.Tx, + tableName string, + dataArray [][]byte, + inputId uint64, + nextIndex uint64, +) error { + lenOutputs := int64(len(dataArray)) + if lenOutputs < 1 { + return nil + } + + rows := [][]any{} + for i, data := range dataArray { + rows = append(rows, []any{inputId, nextIndex + uint64(i), data}) + } + + count, err := tx.CopyFrom( + context.Background(), + pgx.Identifier{tableName}, + []string{"input_id", "index", "raw_data"}, + pgx.CopyFromRows(rows), + ) + if err != nil { + return errors.Join(ErrCopyFrom, err, tx.Rollback(ctx)) + } + if lenOutputs != count { + err := fmt.Errorf("not all %ss were inserted (%d != %d)", tableName, lenOutputs, count) + return errors.Join(err, tx.Rollback(ctx)) + } + + return nil +} + +func (_ *AdvancerRepository) updateInput( + ctx context.Context, + tx pgx.Tx, + inputId uint64, + status InputCompletionStatus, + outputsHash Hash, + machineHash Hash, +) error { + query := ` + UPDATE input + SET (status, outputs_hash, machine_hash) = (@status, @outputsHash, @machineHash) + WHERE id = @id + ` + args := pgx.NamedArgs{ + "status": status, + "outputsHash": outputsHash, + "machineHash": machineHash, + "id": inputId, + } + _, err := tx.Exec(ctx, query, args) + if err != nil { + return errors.Join(ErrUpdateRow, err, tx.Rollback(ctx)) + } + return nil +} + +// ------------------------------------------------------------------------------------------------ + +func toIN[T fmt.Stringer](a []T) string { + s := []string{} + for _, x := range a { + s = append(s, fmt.Sprintf("'\\x%s'", x.String()[2:])) + } + return fmt.Sprintf("(%s)", strings.Join(s, ", ")) +} diff --git a/internal/repository/base.go b/internal/repository/base.go index 9bfc6dac4..ed92ff81d 100644 --- a/internal/repository/base.go +++ b/internal/repository/base.go @@ -19,7 +19,14 @@ type Database struct { db *pgxpool.Pool } -var ErrInsertRow = errors.New("unable to insert row") +var ( + ErrInsertRow = errors.New("unable to insert row") + ErrUpdateRow = errors.New("unable to update row") + ErrCopyFrom = errors.New("unable to COPY FROM") + + ErrBeginTx = errors.New("unable to begin transaction") + ErrCommitTx = errors.New("unable to commit transaction") +) func Connect( ctx context.Context, @@ -141,8 +148,10 @@ func (pg *Database) InsertInput( @blockNumber, @machineHash, @outputsHash, - @applicationAddress)` - + @applicationAddress) + RETURNING + id + ` args := pgx.NamedArgs{ "index": input.Index, "status": input.CompletionStatus, @@ -153,7 +162,7 @@ func (pg *Database) InsertInput( "applicationAddress": input.AppAddress, } - _, err := pg.db.Exec(ctx, query, args) + err := pg.db.QueryRow(ctx, query, args).Scan(&input.Id) if err != nil { return fmt.Errorf("%w: %w", ErrInsertRow, err) } diff --git a/internal/repository/migrations/000001_create_application_input_claim_output_report_nodeconfig.up.sql b/internal/repository/migrations/000001_create_application_input_claim_output_report_nodeconfig.up.sql index 849c0e786..507a63443 100644 --- a/internal/repository/migrations/000001_create_application_input_claim_output_report_nodeconfig.up.sql +++ b/internal/repository/migrations/000001_create_application_input_claim_output_report_nodeconfig.up.sql @@ -64,7 +64,7 @@ CREATE TABLE "output" CONSTRAINT "output_input_id_fkey" FOREIGN KEY ("input_id") REFERENCES "input"("id") ); -CREATE UNIQUE INDEX "output_idx" ON "output"("index"); +CREATE INDEX "output_idx" ON "output"("index"); CREATE TABLE "report" ( @@ -76,7 +76,7 @@ CREATE TABLE "report" CONSTRAINT "report_input_id_fkey" FOREIGN KEY ("input_id") REFERENCES "input"("id") ); -CREATE UNIQUE INDEX "report_idx" ON "report"("index"); +CREATE INDEX "report_idx" ON "report"("index"); CREATE TABLE "node_config" ( diff --git a/internal/repository/schemamanager.go b/internal/repository/schemamanager.go index c0b2d2242..397c8fe75 100644 --- a/internal/repository/schemamanager.go +++ b/internal/repository/schemamanager.go @@ -79,6 +79,10 @@ func (s *SchemaManager) Upgrade() error { return nil } +func (s *SchemaManager) DeleteAll() error { + return s.migrate.Down() +} + func (s *SchemaManager) Close() { source, db := s.migrate.Close() if source != nil { diff --git a/test/advancer/advancer_test.go b/test/advancer/advancer_test.go new file mode 100644 index 000000000..5730f7efd --- /dev/null +++ b/test/advancer/advancer_test.go @@ -0,0 +1,228 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package advancer + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/cartesi/rollups-node/internal/node/advancer" + "github.com/cartesi/rollups-node/internal/node/machine/nodemachine" + "github.com/cartesi/rollups-node/internal/node/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/cartesi/rollups-node/pkg/emulator" + "github.com/cartesi/rollups-node/pkg/rollupsmachine" + "github.com/cartesi/rollups-node/test/snapshot" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" +) + +var appAddress model.Address + +func TestAdvancer(t *testing.T) { + require := require.New(t) + + // Creates the snapshot. + script := "ioctl-echo-loop --vouchers=1 --notices=3 --reports=5 --verbose=1" + snapshot, err := snapshot.FromScript(script, uint64(1_000_000_000)) + require.Nil(err) + defer func() { require.Nil(snapshot.Close()) }() + + // Starts the server. + verbosity := rollupsmachine.ServerVerbosityInfo + address, err := rollupsmachine.StartServer(verbosity, 0, os.Stdout, os.Stderr) + require.Nil(err) + + // Loads the rollupsmachine. + config := &emulator.MachineRuntimeConfig{} + rollupsMachine, err := rollupsmachine.Load(snapshot.Dir, address, config) + require.Nil(err) + require.NotNil(rollupsMachine) + + // Wraps the rollupsmachine with nodemachine. + nodeMachine := nodemachine.New(rollupsMachine, time.Minute, 10) + require.Nil(err) + require.NotNil(nodeMachine) + defer func() { require.Nil(nodeMachine.Close()) }() + + // Creates the machine pool. + appAddress = common.HexToAddress("deadbeef") + machines := advancer.Machines{appAddress: nodeMachine} + + // Creates the background context. + ctx := context.Background() + + // Create the database container. + databaseContainer, err := newDatabaseContainer(ctx) + require.Nil(err) + defer func() { require.Nil(databaseContainer.Terminate(ctx)) }() + + // Setups the database. + database, err := newDatabase(ctx, databaseContainer) + require.Nil(err) + err = populateDatabase(ctx, database) + require.Nil(err) + defer database.Close() + + // Creates the advancer's repository. + repository := &repository.AdvancerRepository{Database: database} + + // Creates the advancer. + advancer, err := advancer.New(machines, repository) + require.Nil(err) + require.NotNil(advancer) + poller, err := advancer.Poller(5 * time.Second) + require.Nil(err) + require.NotNil(poller) + + // Starts the advancer in a separate goroutine. + done := make(chan struct{}, 1) + go func() { + ready := make(chan struct{}, 1) + err = poller.Start(ctx, ready) + <-ready + require.Nil(err, "%v", err) + done <- struct{}{} + }() + + // Orders the advancer to stop after some time has passed. + time.Sleep(5 * time.Second) + poller.Stop() + +wait: + for { + select { + case <-done: + fmt.Println("Done!") + break wait + default: + fmt.Println("Waiting...") + time.Sleep(time.Second) + } + } +} + +func newDatabaseContainer(ctx context.Context) (*postgres.PostgresContainer, error) { + dbName := "cartesinode" + dbUser := "admin" + dbPassword := "password" + + // Start the postgres container + container, err := postgres.Run( + ctx, + "postgres:16-alpine", + postgres.WithDatabase(dbName), + postgres.WithUsername(dbUser), + postgres.WithPassword(dbPassword), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(10*time.Second)), + ) + + return container, err +} + +func newLocalDatabase(ctx context.Context) (*repository.Database, error) { + endpoint := "postgres://renan:renan@localhost:5432/renan?sslmode=disable" + + schemaManager, err := repository.NewSchemaManager(endpoint) + if err != nil { + return nil, err + } + + err = schemaManager.DeleteAll() + if err != nil { + return nil, err + } + + err = schemaManager.Upgrade() + if err != nil { + return nil, err + } + + database, err := repository.Connect(ctx, endpoint) + if err != nil { + return nil, err + } + + return database, nil +} + +func newDatabase( + ctx context.Context, + container *postgres.PostgresContainer, +) (*repository.Database, error) { + endpoint, err := container.ConnectionString(ctx, "sslmode=disable") + if err != nil { + return nil, err + } + + schemaManager, err := repository.NewSchemaManager(endpoint) + if err != nil { + return nil, err + } + + err = schemaManager.Upgrade() + if err != nil { + return nil, err + } + + database, err := repository.Connect(ctx, endpoint) + if err != nil { + return nil, err + } + + return database, nil +} + +func populateDatabase(ctx context.Context, database *repository.Database) (err error) { + application := &model.Application{ + ContractAddress: appAddress, + TemplateHash: [32]byte{}, + SnapshotURI: "invalid", + LastProcessedBlock: 0, + EpochLength: 0, + Status: "RUNNING", + } + err = database.InsertApplication(ctx, application) + if err != nil { + return + } + + inputs := []*model.Input{{ + CompletionStatus: model.InputStatusAccepted, + RawData: []byte("first input"), + AppAddress: appAddress, + }, { + CompletionStatus: model.InputStatusNone, + RawData: []byte("second input"), + AppAddress: appAddress, + }, { + CompletionStatus: model.InputStatusNone, + RawData: []byte("third input"), + AppAddress: appAddress, + }} + + for i, input := range inputs { + input.Index = uint64(i) + input.BlockNumber = uint64(i) + input.RawData, err = rollupsmachine.Input{Data: input.RawData}.Encode() + if err != nil { + return + } + err = database.InsertInput(ctx, input) + if err != nil { + return + } + } + + return +}