From f92402d2f9f54de43b4c5738e4828a7003c96129 Mon Sep 17 00:00:00 2001 From: colinlyguo Date: Thu, 4 Jul 2024 01:47:58 +0800 Subject: [PATCH] add bundle orm tests --- database/migrate/migrations/00021_bundle.sql | 2 +- rollup/internal/orm/bundle.go | 21 +-- rollup/internal/orm/orm_test.go | 157 +++++++++++++++++-- 3 files changed, 155 insertions(+), 25 deletions(-) diff --git a/database/migrate/migrations/00021_bundle.sql b/database/migrate/migrations/00021_bundle.sql index eea42e48ef..0122b04871 100644 --- a/database/migrate/migrations/00021_bundle.sql +++ b/database/migrate/migrations/00021_bundle.sql @@ -3,7 +3,7 @@ CREATE TABLE bundle ( index BIGSERIAL PRIMARY KEY, - hash VARCHAR NOT NULL, -- Not part of DA hash, used for SQL query consistency and ease of use, derived using keccak256(concat(start_batch_hash, end_batch_hash)). + hash VARCHAR NOT NULL, -- Not part of DA hash, used for SQL query consistency and ease of use, derived using keccak256(concat(start_batch_hash_bytes, end_batch_hash_bytes)). start_batch_index BIGINT NOT NULL, end_batch_index BIGINT NOT NULL, start_batch_hash VARCHAR NOT NULL, diff --git a/rollup/internal/orm/bundle.go b/rollup/internal/orm/bundle.go index 7c21fd976f..27a60429f0 100644 --- a/rollup/internal/orm/bundle.go +++ b/rollup/internal/orm/bundle.go @@ -9,6 +9,7 @@ import ( "time" "github.com/scroll-tech/da-codec/encoding" + "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/crypto" "gorm.io/gorm" @@ -21,7 +22,7 @@ type Bundle struct { db *gorm.DB `gorm:"column:-"` // bundle - Index uint64 `json:"index" gorm:"column:index"` + Index uint64 `json:"index" gorm:"column:index;primaryKey"` Hash string `json:"hash" gorm:"column:hash"` StartBatchIndex uint64 `json:"start_batch_index" gorm:"column:start_batch_index"` EndBatchIndex uint64 `json:"end_batch_index" gorm:"column:end_batch_index"` @@ -131,21 +132,7 @@ func (o *Bundle) InsertBundle(ctx context.Context, batches []*Batch, codecVersio db = db.WithContext(ctx) db = db.Model(&Bundle{}) - startBytes, err := hex.DecodeString(batches[0].Hash) - if err != nil { - return nil, fmt.Errorf("Bundle.InsertBundle DecodeString error: %w, batch hash: %v", err, batches[0].Hash) - } - endBytes, err := hex.DecodeString(batches[len(batches)-1].Hash) - if err != nil { - return nil, fmt.Errorf("Bundle.InsertBundle DecodeString error: %w, batch hash: %v", err, batches[len(batches)-1].Hash) - } - - // Not part of DA hash, used for SQL query consistency and ease of use. - // Derived using keccak256(concat(start_batch_hash, end_batch_hash)). - bundleHash := hex.EncodeToString(crypto.Keccak256(append(startBytes, endBytes...))) - newBundle := Bundle{ - Hash: bundleHash, StartBatchHash: batches[0].Hash, StartBatchIndex: batches[0].Index, EndBatchHash: batches[len(batches)-1].Hash, @@ -156,6 +143,10 @@ func (o *Bundle) InsertBundle(ctx context.Context, batches []*Batch, codecVersio CodecVersion: int16(codecVersion), } + // Not part of DA hash, used for SQL query consistency and ease of use. + // Derived using keccak256(concat(start_batch_hash_bytes, end_batch_hash_bytes)). + newBundle.Hash = hex.EncodeToString(crypto.Keccak256(append(common.Hex2Bytes(newBundle.StartBatchHash[2:]), common.Hex2Bytes(newBundle.EndBatchHash[2:])...))) + if err := db.Create(&newBundle).Error; err != nil { return nil, fmt.Errorf("Bundle.InsertBundle Create error: %w, bundle hash: %v", err, newBundle.Hash) } diff --git a/rollup/internal/orm/orm_test.go b/rollup/internal/orm/orm_test.go index 220ad4b839..bd1e565e39 100644 --- a/rollup/internal/orm/orm_test.go +++ b/rollup/internal/orm/orm_test.go @@ -19,6 +19,7 @@ import ( "scroll-tech/common/testcontainers" "scroll-tech/common/types" + "scroll-tech/common/types/message" "scroll-tech/database/migrate" "scroll-tech/rollup/internal/utils" @@ -31,6 +32,7 @@ var ( l2BlockOrm *L2Block chunkOrm *Chunk batchOrm *Batch + bundleOrm *Bundle pendingTransactionOrm *PendingTransaction block1 *encoding.Block @@ -61,6 +63,7 @@ func setupEnv(t *testing.T) { assert.NoError(t, err) assert.NoError(t, migrate.ResetDB(sqlDB)) + bundleOrm = NewBundle(db) batchOrm = NewBatch(db) chunkOrm = NewChunk(db) l2BlockOrm = NewL2Block(db) @@ -269,10 +272,8 @@ func TestBatchOrm(t *testing.T) { assert.NoError(t, migrate.ResetDB(sqlDB)) batch := &encoding.Batch{ - Index: 0, - TotalL1MessagePoppedBefore: 0, - ParentBatchHash: common.Hash{}, - Chunks: []*encoding.Chunk{chunk1}, + Index: 0, + Chunks: []*encoding.Chunk{chunk1}, } batch1, err := batchOrm.InsertBatch(context.Background(), batch, codecVersion, utils.BatchMetrics{}) assert.NoError(t, err) @@ -299,14 +300,11 @@ func TestBatchOrm(t *testing.T) { assert.NoError(t, createErr) batchHash1 = daBatch1.Hash().Hex() } - assert.Equal(t, hash1, batchHash1) batch = &encoding.Batch{ - Index: 1, - TotalL1MessagePoppedBefore: 0, - ParentBatchHash: common.Hash{}, - Chunks: []*encoding.Chunk{chunk2}, + Index: 1, + Chunks: []*encoding.Chunk{chunk2}, } batch2, err := batchOrm.InsertBatch(context.Background(), batch, codecVersion, utils.BatchMetrics{}) assert.NoError(t, err) @@ -382,7 +380,148 @@ func TestBatchOrm(t *testing.T) { assert.NotNil(t, updatedBatch) assert.Equal(t, "finalizeTxHash", updatedBatch.FinalizeTxHash) assert.Equal(t, types.RollupFinalizeFailed, types.RollupStatus(updatedBatch.RollupStatus)) + + batches, err := batchOrm.GetBatchesGEIndex(context.Background(), 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(batches)) + assert.Equal(t, batchHash1, batches[0].Hash) + assert.Equal(t, batchHash2, batches[1].Hash) + + batches, err = batchOrm.GetBatchesGEIndex(context.Background(), 0, 1) + assert.NoError(t, err) + assert.Equal(t, 1, len(batches)) + assert.Equal(t, batchHash1, batches[0].Hash) + + batches, err = batchOrm.GetBatchesGEIndex(context.Background(), 1, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(batches)) + assert.Equal(t, batchHash2, batches[0].Hash) + + err = batchOrm.UpdateBundleHashInRange(context.Background(), 0, 0, "test hash") + assert.NoError(t, err) + + err = batchOrm.UpdateProvingStatusByBundleHash(context.Background(), "test hash", types.ProvingTaskFailed) + assert.NoError(t, err) + + err = batchOrm.UpdateRollupStatusByBundleHash(context.Background(), "test hash", types.RollupFinalizeFailed) + assert.NoError(t, err) + + batches, err = batchOrm.GetBatchesGEIndex(context.Background(), 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(batches)) + assert.Equal(t, batchHash1, batches[0].Hash) + assert.Equal(t, batchHash2, batches[1].Hash) + assert.Equal(t, types.ProvingStatus(batches[0].ProvingStatus), types.ProvingTaskFailed) + assert.Equal(t, types.RollupStatus(batches[0].RollupStatus), types.RollupFinalizeFailed) + } +} + +func TestBundleOrm(t *testing.T) { + sqlDB, err := db.DB() + assert.NoError(t, err) + assert.NoError(t, migrate.ResetDB(sqlDB)) + + chunk1 := &encoding.Chunk{Blocks: []*encoding.Block{block1}} + batch1 := &encoding.Batch{ + Index: 0, + Chunks: []*encoding.Chunk{chunk1}, + } + dbBatch1, err := batchOrm.InsertBatch(context.Background(), batch1, encoding.CodecV3, utils.BatchMetrics{}) + assert.NoError(t, err) + + chunk2 := &encoding.Chunk{Blocks: []*encoding.Block{block2}} + batch2 := &encoding.Batch{ + Index: 1, + Chunks: []*encoding.Chunk{chunk2}, } + dbBatch2, err := batchOrm.InsertBatch(context.Background(), batch2, encoding.CodecV3, utils.BatchMetrics{}) + assert.NoError(t, err) + + var bundle1 *Bundle + var bundle2 *Bundle + + t.Run("InsertBundle", func(t *testing.T) { + bundle1, err = bundleOrm.InsertBundle(context.Background(), []*Batch{dbBatch1}, encoding.CodecV3) + assert.NoError(t, err) + assert.NotNil(t, bundle1) + assert.Equal(t, uint64(0), bundle1.StartBatchIndex) + assert.Equal(t, uint64(0), bundle1.EndBatchIndex) + assert.Equal(t, dbBatch1.Hash, bundle1.StartBatchHash) + assert.Equal(t, dbBatch1.Hash, bundle1.EndBatchHash) + assert.Equal(t, encoding.CodecV3, encoding.CodecVersion(bundle1.CodecVersion)) + + bundle2, err = bundleOrm.InsertBundle(context.Background(), []*Batch{dbBatch2}, encoding.CodecV3) + assert.NoError(t, err) + assert.NotNil(t, bundle2) + assert.Equal(t, uint64(1), bundle2.StartBatchIndex) + assert.Equal(t, uint64(1), bundle2.EndBatchIndex) + assert.Equal(t, dbBatch2.Hash, bundle2.StartBatchHash) + assert.Equal(t, dbBatch2.Hash, bundle2.EndBatchHash) + assert.Equal(t, encoding.CodecV3, encoding.CodecVersion(bundle2.CodecVersion)) + }) + + t.Run("GetFirstUnbundledBatchIndex", func(t *testing.T) { + index, err := bundleOrm.GetFirstUnbundledBatchIndex(context.Background()) + assert.NoError(t, err) + assert.Equal(t, uint64(2), index) + }) + + t.Run("GetFirstPendingBundle", func(t *testing.T) { + bundle, err := bundleOrm.GetFirstPendingBundle(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, bundle) + assert.Equal(t, int16(types.RollupPending), bundle.RollupStatus) + }) + + t.Run("UpdateFinalizeTxHashAndRollupStatus", func(t *testing.T) { + err := bundleOrm.UpdateFinalizeTxHashAndRollupStatus(context.Background(), bundle1.Hash, "0xabcd", types.RollupFinalized) + assert.NoError(t, err) + + pendingBundle, err := bundleOrm.GetFirstPendingBundle(context.Background()) + assert.NoError(t, err) + assert.Equal(t, uint64(2), pendingBundle.Index) + + var finalizedBundle Bundle + err = db.Where("hash = ?", bundle1.Hash).First(&finalizedBundle).Error + assert.NoError(t, err) + assert.Equal(t, "0xabcd", finalizedBundle.FinalizeTxHash) + assert.Equal(t, int16(types.RollupFinalized), finalizedBundle.RollupStatus) + assert.NotNil(t, finalizedBundle.FinalizedAt) + }) + + t.Run("UpdateProvingStatus", func(t *testing.T) { + err := bundleOrm.UpdateProvingStatus(context.Background(), bundle1.Hash, types.ProvingTaskAssigned) + assert.NoError(t, err) + + var bundle Bundle + err = db.Where("hash = ?", bundle1.Hash).First(&bundle).Error + assert.NoError(t, err) + assert.Equal(t, int16(types.ProvingTaskAssigned), bundle.ProvingStatus) + assert.NotNil(t, bundle.ProverAssignedAt) + + err = bundleOrm.UpdateProvingStatus(context.Background(), bundle1.Hash, types.ProvingTaskVerified) + assert.NoError(t, err) + + err = db.Where("hash = ?", bundle1.Hash).First(&bundle).Error + assert.NoError(t, err) + assert.Equal(t, int16(types.ProvingTaskVerified), bundle.ProvingStatus) + assert.NotNil(t, bundle.ProvedAt) + }) + + t.Run("GetVerifiedProofByHash", func(t *testing.T) { + proof := &message.BundleProof{ + Proof: []byte("test proof"), + } + proofBytes, err := json.Marshal(proof) + assert.NoError(t, err) + + err = db.Model(&Bundle{}).Where("hash = ?", bundle1.Hash).Update("proof", proofBytes).Error + assert.NoError(t, err) + + retrievedProof, err := bundleOrm.GetVerifiedProofByHash(context.Background(), bundle1.Hash) + assert.NoError(t, err) + assert.Equal(t, proof.Proof, retrievedProof.Proof) + }) } func TestPendingTransactionOrm(t *testing.T) {