Skip to content

Commit

Permalink
Merge pull request #293 from SiaFoundation/nate/verify-sector-consist…
Browse files Browse the repository at this point in the history
…ency

Sector change consistency
  • Loading branch information
n8maninger authored Feb 6, 2024
2 parents ee849f6 + 4646871 commit 8feb598
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 114 deletions.
12 changes: 7 additions & 5 deletions host/contracts/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,14 @@ func (cu *ContractUpdater) Commit(revision SignedRevision, usage Usage) error {
start := time.Now()
// revise the contract
err := cu.store.ReviseContract(revision, cu.oldRoots, usage, cu.sectorActions)
if err == nil {
// clear the committed sector actions
cu.sectorActions = cu.sectorActions[:0]
if err != nil {
return err
}

// clear the committed sector actions
cu.sectorActions = cu.sectorActions[:0]
// update the roots cache
cu.rootsCache.Add(revision.Revision.ParentID, cu.sectorRoots[:])
cu.rootsCache.Add(revision.Revision.ParentID, append([]types.Hash256(nil), cu.sectorRoots...))
cu.log.Debug("contract update committed", zap.String("contractID", revision.Revision.ParentID.String()), zap.Uint64("revision", revision.Revision.RevisionNumber), zap.Duration("elapsed", time.Since(start)))
return err
return nil
}
200 changes: 133 additions & 67 deletions persist/sqlite/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type (
contractSectorRootRef struct {
dbID int64
sectorID int64
root types.Hash256
}
)

Expand Down Expand Up @@ -260,7 +261,7 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac
}

// ReviseContract atomically updates a contract's revision and sectors
func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error {
func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error {
return s.transaction(func(tx txn) error {
// revise the contract
contractID, err := reviseContract(tx, revision)
Expand All @@ -277,31 +278,57 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []typ
}

// update the sector roots
sectors := uint64(len(oldRoots))
sectors := uint64(len(roots))
roots := append([]types.Hash256(nil), roots...)
for _, change := range sectorChanges {
switch change.Action {
case contracts.SectorActionAppend:
if err := appendSector(tx, contractID, change.Root, sectors); err != nil {
return fmt.Errorf("failed to append sector: %w", err)
}
sectors++
roots = append(roots, change.Root)
case contracts.SectorActionTrim:
if sectors < change.A {
return fmt.Errorf("cannot trim %v sectors from contract with %v sectors", change.A, sectors)
}

if err := trimSectors(tx, contractID, change.A, s.log); err != nil {
trimmed, err := trimSectors(tx, contractID, change.A, s.log)
if err != nil {
return fmt.Errorf("failed to trim sectors: %w", err)
}
sectors -= change.A
removed := roots[len(roots)-int(change.A):]
for _, root := range removed {
if !trimmed[root] {
return fmt.Errorf("inconsistent sector trim: expected %s to be trimmed", root)
}
}
roots = roots[:len(roots)-int(change.A)]
case contracts.SectorActionUpdate:
if err := updateSector(tx, contractID, change.Root, change.A); err != nil {
oldRoot, err := updateSector(tx, contractID, change.Root, change.A)
if err != nil {
return fmt.Errorf("failed to update sector: %w", err)
} else if roots[change.A] != oldRoot {
return fmt.Errorf("inconsistent sector update (%d): expected old sector %s, got %s", change.A, roots[change.A], oldRoot)
}
roots[change.A] = change.Root
case contracts.SectorActionSwap:
if err := swapSectors(tx, contractID, change.A, change.B); err != nil {
if change.A > change.B {
change.A, change.B = change.B, change.A
}

swapped, err := swapSectors(tx, contractID, change.A, change.B)
if err != nil {
return fmt.Errorf("failed to swap sectors: %w", err)
}
oldA, oldB := roots[change.A], roots[change.B]
for root := range swapped {
if root != oldA && root != oldB {
return fmt.Errorf("inconsistent sector swap: expected %s or %s, got %s", oldA, oldB, root)
}
}
roots[change.A], roots[change.B] = roots[change.B], roots[change.A]
}
}
return nil
Expand Down Expand Up @@ -545,93 +572,105 @@ func appendSector(tx txn, contractID int64, root types.Hash256, index uint64) er
return nil
}

func updateSector(tx txn, contractID int64, root types.Hash256, index uint64) error {
var oldSectorID int64
if err := tx.QueryRow(`SELECT sector_id FROM contract_sector_roots WHERE contract_id=$1 AND root_index=$2`, contractID, index).Scan(&oldSectorID); err != nil {
return fmt.Errorf("failed to get old sector id: %w", err)
func updateSector(tx txn, contractID int64, root types.Hash256, index uint64) (types.Hash256, error) {
row := tx.QueryRow(`SELECT csr.id, csr.sector_id, ss.sector_root
FROM contract_sector_roots csr
INNER JOIN stored_sectors ss ON (csr.sector_id = ss.id)
WHERE contract_id=$1 AND root_index=$2`, contractID, index)
ref, err := scanContractSectorRootRef(row)
if err != nil {
return types.Hash256{}, fmt.Errorf("failed to get old sector id: %w", err)
}

const query = `WITH sector AS (
SELECT id FROM stored_sectors WHERE sector_root=$1
)
UPDATE contract_sector_roots
SET sector_id=sector.id
FROM sector
WHERE contract_id=$2 AND root_index=$3
RETURNING sector_id;`
// update the sector ID
var newSectorID int64
err := tx.QueryRow(query, sqlHash256(root), contractID, index).Scan(&newSectorID)
err = tx.QueryRow(`WITH sector AS (
SELECT id FROM stored_sectors WHERE sector_root=$1
)
UPDATE contract_sector_roots
SET sector_id=sector.id
FROM sector
WHERE contract_sector_roots.id=$2
RETURNING sector_id;`, sqlHash256(root), ref.dbID).Scan(&newSectorID)
if err != nil {
return err
} else if err := pruneSectorRef(tx, oldSectorID); err != nil {
return fmt.Errorf("failed to prune sector ref: %w", err)
return types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err)
}
return nil
// prune the old sector ID
if _, err := pruneSectorRef(tx, ref.sectorID); err != nil {
return types.Hash256{}, fmt.Errorf("failed to prune old sector: %w", err)
}
return ref.root, nil
}

func swapSectors(tx txn, contractID int64, i, j uint64) error {
func swapSectors(tx txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) {
if i == j {
return nil
return nil, nil
}

var records []contractSectorRootRef
rows, err := tx.Query(`SELECT id, sector_id FROM contract_sector_roots WHERE contract_id=$1 AND root_index IN ($2, $3);`, contractID, i, j)
rows, err := tx.Query(`SELECT csr.id, csr.sector_id, ss.sector_root
FROM contract_sector_roots csr
INNER JOIN stored_sectors ss ON (ss.id = csr.sector_id)
WHERE contract_id=$1 AND root_index IN ($2, $3)
ORDER BY root_index ASC;`, contractID, i, j)
if err != nil {
return fmt.Errorf("failed to query sector IDs: %w", err)
return nil, fmt.Errorf("failed to query sector IDs: %w", err)
}
defer rows.Close()
for rows.Next() {
var record contractSectorRootRef
if err := rows.Scan(&record.dbID, &record.sectorID); err != nil {
return fmt.Errorf("failed to scan sector ID: %w", err)
ref, err := scanContractSectorRootRef(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan sector ref: %w", err)
}
records = append(records, record)
records = append(records, ref)
}

if len(records) != 2 {
return errors.New("failed to find both sectors")
return nil, errors.New("failed to find both sectors")
}

stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2`)
stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2 RETURNING sector_id;`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %w", err)
return nil, fmt.Errorf("failed to prepare update statement: %w", err)
}
defer stmt.Close()

res, err := stmt.Exec(records[1].sectorID, records[0].dbID)
var newSectorID int64
err = stmt.QueryRow(records[1].sectorID, records[0].dbID).Scan(&newSectorID)
if err != nil {
return fmt.Errorf("failed to update sector ID: %w", err)
} else if rows, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if rows != 1 {
return fmt.Errorf("expected 1 row affected, got %v", rows)
return nil, fmt.Errorf("failed to update sector ID: %w", err)
} else if newSectorID != records[1].sectorID {
return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
}

res, err = stmt.Exec(records[0].sectorID, records[1].dbID)
err = stmt.QueryRow(records[0].sectorID, records[1].dbID).Scan(&newSectorID)
if err != nil {
return fmt.Errorf("failed to update sector ID: %w", err)
} else if rows, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if rows != 1 {
return fmt.Errorf("expected 1 row affected, got %v", rows)
return nil, fmt.Errorf("failed to update sector ID: %w", err)
} else if newSectorID != records[0].sectorID {
return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
}

return nil
return map[types.Hash256]bool{
records[0].root: true,
records[1].root: true,
}, nil
}

// lastContractSectors returns the last n sector IDs for a contract.
func lastContractSectors(tx txn, contractID int64, n uint64) (roots []contractSectorRootRef, err error) {
const query = `SELECT id, sector_id FROM contract_sector_roots WHERE contract_id=$1 ORDER BY root_index DESC LIMIT $2;`
const query = `SELECT csr.id, csr.sector_id, ss.sector_root FROM contract_sector_roots csr
INNER JOIN stored_sectors ss ON (csr.sector_id=ss.id)
WHERE contract_id=$1 ORDER BY root_index DESC LIMIT $2;`
rows, err := tx.Query(query, contractID, n)
if err != nil {
return nil, err
}
defer rows.Close()

for rows.Next() {
var ref contractSectorRootRef
if err := rows.Scan(&ref.dbID, &ref.sectorID); err != nil {
return nil, err
ref, err := scanContractSectorRootRef(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan sector ref: %w", err)
}
roots = append(roots, ref)
}
Expand All @@ -647,14 +686,26 @@ func deleteContractSectors(tx txn, refs []contractSectorRootRef) (int, error) {
}

// delete the sector roots
query := `DELETE FROM contract_sector_roots WHERE id IN (` + queryPlaceHolders(len(rootIDs)) + `);`
res, err := tx.Exec(query, queryArgs(rootIDs)...)
query := `DELETE FROM contract_sector_roots WHERE id IN (` + queryPlaceHolders(len(rootIDs)) + `) RETURNING id;`
rows, err := tx.Query(query, queryArgs(rootIDs)...)
if err != nil {
return 0, fmt.Errorf("failed to delete sectors: %w", err)
} else if rows, err := res.RowsAffected(); err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
} else if rows != int64(len(refs)) {
return 0, fmt.Errorf("failed to delete all sectors: %w", err)
}
deleted := make(map[int64]bool)
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return 0, fmt.Errorf("failed to scan deleted sector: %w", err)
}
deleted[id] = true
}
if len(deleted) != len(rootIDs) {
return 0, errors.New("failed to delete all sectors")
}
for _, rootID := range rootIDs {
if !deleted[rootID] {
return 0, errors.New("failed to delete all sectors")
}
}

// decrement the contract metrics
Expand All @@ -665,25 +716,30 @@ func deleteContractSectors(tx txn, refs []contractSectorRootRef) (int, error) {
// attempt to prune the deleted sectors
var pruned int
for _, ref := range refs {
if err := pruneSectorRef(tx, ref.sectorID); errors.Is(err, errSectorHasRefs) {
continue
} else if err != nil {
deleted, err := pruneSectorRef(tx, ref.sectorID)
if err != nil {
return 0, fmt.Errorf("failed to prune sector ref: %w", err)
} else if deleted {
pruned++
}
pruned++
}
return pruned, nil
}

// trimSectors deletes the last n sector roots for a contract.
func trimSectors(tx txn, contractID int64, n uint64, log *zap.Logger) error {
func trimSectors(tx txn, contractID int64, n uint64, log *zap.Logger) (map[types.Hash256]bool, error) {
refs, err := lastContractSectors(tx, contractID, n)
if err != nil {
return fmt.Errorf("failed to get sector IDs: %w", err)
return nil, fmt.Errorf("failed to get sector IDs: %w", err)
} else if _, err = deleteContractSectors(tx, refs); err != nil {
return nil, fmt.Errorf("failed to delete sectors: %w", err)
}

_, err = deleteContractSectors(tx, refs)
return err
roots := make(map[types.Hash256]bool)
for _, ref := range refs {
roots[ref.root] = true
}
return roots, nil
}

// clearContract clears a contract and returns its ID
Expand Down Expand Up @@ -1159,8 +1215,18 @@ func setContractStatus(tx txn, id types.FileContractID, status contracts.Contrac
return nil
}

func scanContractSectorRef(s scanner) (ref contractSectorRef, err error) {
err = s.Scan(&ref.ID, (*sqlHash256)(&ref.ContractID), &ref.SectorID)
return
}

func scanContractSectorRootRef(s scanner) (ref contractSectorRootRef, err error) {
err = s.Scan(&ref.dbID, &ref.sectorID, (*sqlHash256)(&ref.root))
return
}

func expiredContractSectors(tx txn, height uint64, batchSize int64) (sectors []contractSectorRef, _ error) {
const query = `SELECT csr.id, c.contract_id, csr.sector_id FROM contract_sector_roots csr
const query = `SELECT csr.id, c.contract_id, csr.sector_id FROM contract_sector_roots csr
INNER JOIN contracts c ON (csr.contract_id=c.id)
-- past proof window or not confirmed and past the rebroadcast height
WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3;`
Expand All @@ -1170,8 +1236,8 @@ WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3;`
}
defer rows.Close()
for rows.Next() {
var ref contractSectorRef
if err := rows.Scan(&ref.ID, (*sqlHash256)(&ref.ContractID), &ref.SectorID); err != nil {
ref, err := scanContractSectorRef(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan expired contract: %w", err)
}
sectors = append(sectors, ref)
Expand Down
20 changes: 1 addition & 19 deletions persist/sqlite/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,6 @@ func rootsEqual(a, b []types.Hash256) error {
return nil
}

func runRevision(db *Store, revision contracts.SignedRevision, roots []types.Hash256, changes []contracts.SectorChange) error {
for _, change := range changes {
switch change.Action {
// store a sector in the database for the append or update actions
case contracts.SectorActionAppend, contracts.SectorActionUpdate:
root := frand.Entropy256()
release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil })
if err != nil {
return fmt.Errorf("failed to store sector: %w", err)
}
defer release()
change.Root = root
}
}

return db.ReviseContract(revision, roots, contracts.Usage{}, changes)
}

func TestReviseContract(t *testing.T) {
log := zaptest.NewLogger(t)
db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log)
Expand Down Expand Up @@ -302,7 +284,7 @@ func TestReviseContract(t *testing.T) {
}
}

if err := runRevision(db, contract, oldRoots, test.changes); err != nil {
if err := db.ReviseContract(contract, oldRoots, contracts.Usage{}, test.changes); err != nil {
if test.errors {
t.Log("received error:", err)
return
Expand Down
Loading

0 comments on commit 8feb598

Please sign in to comment.