Skip to content

Commit

Permalink
feat(advancer): add the advancer's repository
Browse files Browse the repository at this point in the history
  • Loading branch information
renan061 committed Aug 27, 2024
1 parent 38f9595 commit e197c24
Show file tree
Hide file tree
Showing 16 changed files with 840 additions and 163 deletions.
13 changes: 6 additions & 7 deletions cmd/cartesi-rollups-cli/root/db/check/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
package check

import (
"fmt"
"log/slog"

"github.com/cartesi/rollups-node/cmd/cartesi-rollups-cli/root/common"
"github.com/cartesi/rollups-node/internal/repository"
"github.com/cartesi/rollups-node/internal/repository/schema"
"github.com/spf13/cobra"
)

Expand All @@ -17,13 +17,12 @@ var Cmd = &cobra.Command{
}

func run(cmd *cobra.Command, args []string) {

schemaManager, err := repository.NewSchemaManager(common.PostgresEndpoint)
schema, err := schema.New(common.PostgresEndpoint)
cobra.CheckErr(err)
defer schemaManager.Close()
defer schema.Close()

err = schemaManager.ValidateSchemaVersion()
version, err := schema.ValidateVersion()
cobra.CheckErr(err)

fmt.Printf("Database Schema is at the correct version: %d\n", repository.EXPECTED_VERSION)
slog.Info("Database Schema is at the correct version.", "version", version)
}
19 changes: 6 additions & 13 deletions cmd/cartesi-rollups-cli/root/db/upgrade/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
package upgrade

import (
"fmt"
"log/slog"

"github.com/cartesi/rollups-node/cmd/cartesi-rollups-cli/root/common"
"github.com/cartesi/rollups-node/internal/repository"
"github.com/cartesi/rollups-node/internal/repository/schema"
"github.com/spf13/cobra"
)

Expand All @@ -18,21 +17,15 @@ var Cmd = &cobra.Command{
}

func run(cmd *cobra.Command, args []string) {

schemaManager, err := repository.NewSchemaManager(common.PostgresEndpoint)
schema, err := schema.New(common.PostgresEndpoint)
cobra.CheckErr(err)
defer schemaManager.Close()
defer schema.Close()

err = schemaManager.Upgrade()
err = schema.Up()
cobra.CheckErr(err)

version, err := schemaManager.GetVersion()
version, err := schema.ValidateVersion()
cobra.CheckErr(err)

if repository.EXPECTED_VERSION != version {
slog.Warn("Current version is different to expected one")
}

fmt.Printf("Database Schema successfully Updated. Current version is %d\n", version)

slog.Info("Database Schema successfully Updated.", "version", version)
}
16 changes: 10 additions & 6 deletions internal/node/advancer/advancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func (advancer *Advancer) Step(ctx context.Context) error {
}
}

// Updates the status of the epochs.
for _, app := range apps {
err := advancer.repository.UpdateEpochs(ctx, app)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -99,11 +107,7 @@ func (advancer *Advancer) process(ctx context.Context, app Address, inputs []*In
}
}

// Updates the status of the epochs based on the last processed input.
lastInput := inputs[len(inputs)-1]
err := advancer.repository.UpdateEpochs(ctx, app, lastInput)

return err
return nil
}

// ------------------------------------------------------------------------------------------------
Expand All @@ -114,7 +118,7 @@ type Repository interface {

StoreAdvanceResult(context.Context, *Input, *nodemachine.AdvanceResult) error

UpdateEpochs(_ context.Context, app Address, lastInput *Input) error
UpdateEpochs(_ context.Context, app Address) error
}

// A map of application addresses to machines.
Expand Down
29 changes: 9 additions & 20 deletions internal/node/startup/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/cartesi/rollups-node/internal/node/config"
"github.com/cartesi/rollups-node/internal/node/model"
"github.com/cartesi/rollups-node/internal/repository"
"github.com/cartesi/rollups-node/internal/repository/schema"
"github.com/ethereum/go-ethereum/common"
"github.com/jackc/pgx/v5"
"github.com/lmittmann/tint"
Expand All @@ -20,31 +21,19 @@ import (

// Validates the Node Database Schema Version
func ValidateSchema(config config.NodeConfig) error {
var (
schemaManager *repository.SchemaManager
err error
)

if !config.PostgresSslMode {
schemaManager, err = repository.NewSchemaManager(
fmt.Sprintf("%v?sslmode=disable", config.PostgresEndpoint.Value))
if err != nil {
return err
}
} else {
schemaManager, err = repository.NewSchemaManager(config.PostgresEndpoint.Value)
if err != nil {
return err
}
endpoint := config.PostgresEndpoint.Value
if config.PostgresSslMode {
endpoint += "?sslmode=disable"
}
defer schemaManager.Close()
err = schemaManager.ValidateSchemaVersion()

schema, err := schema.New(endpoint)
if err != nil {
return err
}
defer schema.Close()

return nil

_, err = schema.ValidateVersion()
return err
}

// Configure the node logs
Expand Down
221 changes: 221 additions & 0 deletions internal/repository/advancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// (c) Cartesi and individual authors (see AUTHORS)
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)

package repository

import (
"context"
"errors"
"fmt"
"strings"

. "github.com/cartesi/rollups-node/internal/node/model"
"github.com/cartesi/rollups-node/internal/nodemachine"
"github.com/jackc/pgx/v5"
)

var ErrAdvancerRepository = errors.New("advancer repository error")

type AdvancerRepository struct{ *Database }

func (repo *AdvancerRepository) GetInputs(
ctx context.Context,
apps []Address,
) (map[Address][]*Input, error) {
result := map[Address][]*Input{}
if len(apps) == 0 {
return result, nil
}

query := fmt.Sprintf(`
SELECT id, application_address, raw_data
FROM input
WHERE status = 'NONE'
AND application_address IN %s
ORDER BY index ASC, application_address
`, toSqlIn(apps)) // NOTE: not sanitized
rows, err := repo.db.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("%w (failed querying inputs): %w", ErrAdvancerRepository, err)
}

var input Input
scans := []any{&input.Id, &input.AppAddress, &input.RawData}
_, err = pgx.ForEachRow(rows, scans, func() error {
input := input
if _, ok := result[input.AppAddress]; ok { //nolint:gosimple
result[input.AppAddress] = append(result[input.AppAddress], &input)
} else {
result[input.AppAddress] = []*Input{&input}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("%w (failed reading input rows): %w", ErrAdvancerRepository, err)
}

return result, nil
}

func (repo *AdvancerRepository) StoreResults(
ctx context.Context,
input *Input,
res *nodemachine.AdvanceResult,
) error {
tx, err := repo.db.Begin(ctx)
if err != nil {
return errors.Join(ErrBeginTx, err)
}

// Inserts the outputs.
nextOutputIndex, err := repo.getNextIndex(ctx, tx, "output", input.AppAddress)
if err != nil {
return err
}
err = repo.insert(ctx, tx, "output", res.Outputs, input.Id, nextOutputIndex)
if err != nil {
return err
}

// Inserts the reports.
nextReportIndex, err := repo.getNextIndex(ctx, tx, "report", input.AppAddress)
if err != nil {
return err
}
err = repo.insert(ctx, tx, "report", res.Reports, input.Id, nextReportIndex)
if err != nil {
return err
}

// Updates the input's status.
err = repo.updateInput(ctx, tx, input.Id, res.Status, res.OutputsHash, res.MachineHash)
if err != nil {
return err
}

err = tx.Commit(ctx)
if err != nil {
return errors.Join(ErrCommitTx, err, tx.Rollback(ctx))
}

return nil
}

func (repo *AdvancerRepository) UpdateEpochs(ctx context.Context, app Address) error {
query := `
UPDATE epoch
SET status = 'PROCESSED_ALL_INPUTS'
WHERE id IN ((
SELECT DISTINCT epoch.id
FROM epoch INNER JOIN input ON (epoch.id = input.epoch_id)
WHERE epoch.application_address = @applicationAddress
AND epoch.status = 'CLOSED'
AND input.status != 'NONE'
) EXCEPT (
SELECT DISTINCT epoch.id
FROM epoch INNER JOIN input ON (epoch.id = input.epoch_id)
WHERE epoch.application_address = @applicationAddress
AND epoch.status = 'CLOSED'
AND input.status = 'NONE'))
`
args := pgx.NamedArgs{"applicationAddress": app}
_, err := repo.db.Exec(ctx, query, args)
if err != nil {
return errors.Join(ErrUpdateRow, err)
}
return nil
}

// ------------------------------------------------------------------------------------------------

func (_ *AdvancerRepository) getNextIndex(
ctx context.Context,
tx pgx.Tx,
tableName string,
appAddress Address,
) (uint64, error) {
var nextIndex uint64
query := fmt.Sprintf(`
SELECT COALESCE(MAX(%s.index) + 1, 0)
FROM input INNER JOIN %s ON input.id = %s.input_id
WHERE input.status = 'ACCEPTED'
AND input.application_address = $1
`, tableName, tableName, tableName)
err := tx.QueryRow(ctx, query, appAddress).Scan(&nextIndex)
if err != nil {
err = fmt.Errorf("failed to get the next %s index: %w", tableName, err)
return 0, errors.Join(err, tx.Rollback(ctx))
}
return nextIndex, nil
}

func (_ *AdvancerRepository) insert(
ctx context.Context,
tx pgx.Tx,
tableName string,
dataArray [][]byte,
inputId uint64,
nextIndex uint64,
) error {
lenOutputs := int64(len(dataArray))
if lenOutputs < 1 {
return nil
}

rows := [][]any{}
for i, data := range dataArray {
rows = append(rows, []any{inputId, nextIndex + uint64(i), data})
}

count, err := tx.CopyFrom(
ctx,
pgx.Identifier{tableName},
[]string{"input_id", "index", "raw_data"},
pgx.CopyFromRows(rows),
)
if err != nil {
return errors.Join(ErrCopyFrom, err, tx.Rollback(ctx))
}
if lenOutputs != count {
err := fmt.Errorf("not all %ss were inserted (%d != %d)", tableName, lenOutputs, count)
return errors.Join(err, tx.Rollback(ctx))
}

return nil
}

func (_ *AdvancerRepository) updateInput(
ctx context.Context,
tx pgx.Tx,
inputId uint64,
status InputCompletionStatus,
outputsHash Hash,
machineHash *Hash,
) error {
query := `
UPDATE input
SET (status, outputs_hash, machine_hash) = (@status, @outputsHash, @machineHash)
WHERE id = @id
`
args := pgx.NamedArgs{
"status": status,
"outputsHash": outputsHash,
"machineHash": machineHash,
"id": inputId,
}
_, err := tx.Exec(ctx, query, args)
if err != nil {
return errors.Join(ErrUpdateRow, err, tx.Rollback(ctx))
}
return nil
}

// ------------------------------------------------------------------------------------------------

func toSqlIn[T fmt.Stringer](a []T) string {
s := []string{}
for _, x := range a {
s = append(s, fmt.Sprintf("'\\x%s'", x.String()[2:]))
}
return fmt.Sprintf("(%s)", strings.Join(s, ", "))
}
Loading

0 comments on commit e197c24

Please sign in to comment.