From 80d231f733e9dd8ca166c3d670470ed9a1c165d9 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 26 Oct 2022 09:58:17 +0200 Subject: [PATCH] Hierarchical tests in extendeddatacrossword_test and datasquare_test (#96) * Hierarchical tests in extendeddatacrossword_test and datasquare_test * Better test separation, looping over testing structs * fixing expected values in TestNewDataSquare --- datasquare_test.go | 100 ++++++----- extendeddatacrossword.go | 4 +- extendeddatacrossword_test.go | 314 +++++++++++++++++++++------------- 3 files changed, 242 insertions(+), 176 deletions(-) diff --git a/datasquare_test.go b/datasquare_test.go index 7f48a70..e444a41 100644 --- a/datasquare_test.go +++ b/datasquare_test.go @@ -9,30 +9,43 @@ import ( ) func TestNewDataSquare(t *testing.T) { - result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(result.squareRow, [][][]byte{{{1, 2}}}) { - t.Errorf("newDataSquare failed for 1x1 square") - } - - result, err = newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(result.squareRow, [][][]byte{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}) { - t.Errorf("newDataSquare failed for 2x2 square") - } - - _, err = newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}}, NewDefaultTree) - if err == nil { - t.Errorf("newDataSquare failed; inconsistent number of chunks accepted") + tests := []struct { + name string + cells [][]byte + expected [][][]byte + }{ + {"1x1", [][]byte{{1, 2}}, [][][]byte{{{1, 2}}}}, + {"2x2", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, [][][]byte{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, err := newDataSquare(test.cells, NewDefaultTree) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(result.squareRow, test.expected) { + t.Errorf("newDataSquare failed for %v square", test.name) + } + }) } +} - _, err = newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7}}, NewDefaultTree) - if err == nil { - t.Errorf("newDataSquare failed; chunks of unequal size accepted") +func TestInvalidDataSquareCreation(t *testing.T) { + tests := []struct { + name string + cells [][]byte + }{ + {"InconsistentChunkNumber", [][]byte{{1, 2}, {3, 4}, {5, 6}}}, + {"UnequalChunkSize", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7}}}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := newDataSquare(test.cells, NewDefaultTree) + if err == nil { + t.Errorf("newDataSquare failed; chunks accepted with %v", test.name) + } + }) } } @@ -85,15 +98,6 @@ func TestExtendSquare(t *testing.T) { if err != nil { panic(err) } - err = ds.extendSquare(1, []byte{0}) - if err == nil { - t.Errorf("extendSquare failed; error not returned when filler chunk size does not match data square chunk size") - } - - ds, err = newDataSquare([][]byte{{1, 2}}, NewDefaultTree) - if err != nil { - panic(err) - } err = ds.extendSquare(1, []byte{0, 0}) if err != nil { panic(err) @@ -103,6 +107,17 @@ func TestExtendSquare(t *testing.T) { } } +func TestInvalidSquareExtension(t *testing.T) { + ds, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) + if err != nil { + panic(err) + } + err = ds.extendSquare(1, []byte{0}) + if err == nil { + t.Errorf("extendSquare failed; error not returned when filler chunk size does not match data square chunk size") + } +} + func TestRoots(t *testing.T) { result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) if err != nil { @@ -143,14 +158,14 @@ func TestRootAPI(t *testing.T) { for i := uint(0); i < square.width; i++ { if !reflect.DeepEqual(square.getRowRoots()[i], square.getRowRoot(i)) { t.Errorf( - "Row root API results in different roots, expected %v go %v", + "Row root API results in different roots, expected %v got %v", square.getRowRoots()[i], square.getRowRoot(i), ) } if !reflect.DeepEqual(square.getColRoots()[i], square.getColRoot(i)) { t.Errorf( - "Column root API results in different roots, expected %v go %v", + "Column root API results in different roots, expected %v got %v", square.getColRoots()[i], square.getColRoot(i), ) @@ -167,6 +182,7 @@ func TestDefaultTreeProofs(t *testing.T) { if err != nil { t.Errorf("Got unexpected error: %v", err) } + if len(proof) != 2 { t.Errorf("computing row proof for (1, 1) in 2x2 square failed; expecting proof set of length 2") } @@ -176,24 +192,6 @@ func TestDefaultTreeProofs(t *testing.T) { if numLeaves != 2 { t.Errorf("computing row proof for (1, 1) in 2x2 square failed; expecting number of leaves to be 2") } - - result, err = newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree) - if err != nil { - panic(err) - } - _, proof, proofIndex, numLeaves, err = computeColProof(result, 1, 1) - if err != nil { - t.Errorf("Got unexpected error: %v", err) - } - if len(proof) != 2 { - t.Errorf("computing column proof for (1, 1) in 2x2 square failed; expecting proof set of length 2") - } - if proofIndex != 1 { - t.Errorf("computing column proof for (1, 1) in 2x2 square failed; expecting proof index of 1") - } - if numLeaves != 2 { - t.Errorf("computing column proof for (1, 1) in 2x2 square failed; expecting number of leaves to be 2") - } } func BenchmarkEDSRoots(b *testing.B) { diff --git a/extendeddatacrossword.go b/extendeddatacrossword.go index fa266c4..6dd811a 100644 --- a/extendeddatacrossword.go +++ b/extendeddatacrossword.go @@ -310,7 +310,7 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck( rowIsComplete := noMissingData(eds.row(i)) colIsComplete := noMissingData(eds.col(i)) - // if there's no missing data in the this row + // if there's no missing data in this row if rowIsComplete { errs.Go(func() error { // ensure that the roots are equal @@ -321,7 +321,7 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck( }) } - // if there's no missing data in the this col + // if there's no missing data in this col if colIsComplete { errs.Go(func() error { // ensure that the roots are equal diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 3f863a0..7550691 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -19,143 +19,194 @@ type PseudoFraudProof struct { } func TestRepairExtendedDataSquare(t *testing.T) { - for codecName, codec := range codecs { + bufferSize := 64 + tests := []struct { + name string + // Size of each share, in bytes + shareSize int + codec Codec + }{ + {"leopard", bufferSize, NewLeoRSCodec()}, + {"infectiousGF8", bufferSize, NewRSGF8Codec()}, + } - bufferSize := 64 - ones := bytes.Repeat([]byte{1}, bufferSize) - twos := bytes.Repeat([]byte{2}, bufferSize) - threes := bytes.Repeat([]byte{3}, bufferSize) - fours := bytes.Repeat([]byte{4}, bufferSize) - - original, err := ComputeExtendedDataSquare([][]byte{ - ones, twos, - threes, fours, - }, codec, NewDefaultTree) - if err != nil { - panic(err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + name, codec, shareSize := test.name, test.codec, test.shareSize + original := createTestEds(codec, shareSize) - rowRoots := original.RowRoots() - colRoots := original.ColRoots() + rowRoots := original.RowRoots() + colRoots := original.ColRoots() - flattened := original.Flattened() - flattened[0], flattened[2], flattened[3] = nil, nil, nil - flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil - flattened[8], flattened[9], flattened[10] = nil, nil, nil - flattened[12], flattened[13] = nil, nil + // Verify that an EDS can be repaired after the maximum amount of erasures + t.Run("MaximumErasures", func(t *testing.T) { + flattened := original.Flattened() + flattened[0], flattened[2], flattened[3] = nil, nil, nil + flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil + flattened[8], flattened[9], flattened[10] = nil, nil, nil + flattened[12], flattened[13] = nil, nil - // Re-import the data square. - eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) - if err != nil { - t.Errorf("ImportExtendedDataSquare failed: %v", err) - } + // Re-import the data square. + eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) + if err != nil { + t.Errorf("ImportExtendedDataSquare failed: %v", err) + } - err = eds.Repair(rowRoots, colRoots) - if err != nil { - t.Errorf("unexpected err while repairing data square: %v, codec: :%s", err, codecName) - } else { - assert.Equal(t, original.GetCell(0, 0), ones) - assert.Equal(t, original.GetCell(0, 1), twos) - assert.Equal(t, original.GetCell(1, 0), threes) - assert.Equal(t, original.GetCell(1, 1), fours) - } + err = eds.Repair(rowRoots, colRoots) + if err != nil { + t.Errorf("unexpected err while repairing data square: %v, codec: :%s", err, name) + } else { + assert.Equal(t, original.GetCell(0, 0), bytes.Repeat([]byte{1}, shareSize)) + assert.Equal(t, original.GetCell(0, 1), bytes.Repeat([]byte{2}, shareSize)) + assert.Equal(t, original.GetCell(1, 0), bytes.Repeat([]byte{3}, shareSize)) + assert.Equal(t, original.GetCell(1, 1), bytes.Repeat([]byte{4}, shareSize)) + } + }) - flattened = original.Flattened() - flattened[0], flattened[2], flattened[3] = nil, nil, nil - flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil - flattened[8], flattened[9], flattened[10] = nil, nil, nil - flattened[12], flattened[13], flattened[14] = nil, nil, nil + // Verify that an EDS returns an error when there are too many erasures + t.Run("Unrepairable", func(t *testing.T) { + flattened := original.Flattened() + flattened[0], flattened[2], flattened[3] = nil, nil, nil + flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil + flattened[8], flattened[9], flattened[10] = nil, nil, nil + flattened[12], flattened[13], flattened[14] = nil, nil, nil - // Re-import the data square. - eds, err = ImportExtendedDataSquare(flattened, codec, NewDefaultTree) - if err != nil { - t.Errorf("ImportExtendedDataSquare failed: %v", err) - } + // Re-import the data square. + eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) + if err != nil { + t.Errorf("ImportExtendedDataSquare failed: %v", err) + } - err = eds.Repair(rowRoots, colRoots) - if err == nil { - t.Errorf("did not return an error on trying to repair an unrepairable square") - } - var corrupted ExtendedDataSquare - corrupted, err = original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - corruptChunk := bytes.Repeat([]byte{66}, bufferSize) - corrupted.setCell(0, 0, corruptChunk) - err = corrupted.Repair(rowRoots, colRoots) - if err == nil { - t.Errorf("did not return an error on trying to repair a square with bad roots") - } + err = eds.Repair(rowRoots, colRoots) + if err != ErrUnrepairableDataSquare { + t.Errorf("did not return an error on trying to repair an unrepairable square") + } + }) + }) + } +} - corrupted, err = original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - corrupted.setCell(0, 0, corruptChunk) - err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) - var byzData *ErrByzantineData - if !errors.As(err, &byzData) { - // due to parallelisation, the ErrByzantineData axis may be either row or col - t.Errorf("did not return a ErrByzantineData for a bad row or col; got: %v", err) - } - // Construct the fraud proof - fraudProof := PseudoFraudProof{0, byzData.Index, byzData.Shares} - // Verify the fraud proof - // TODO in a real fraud proof, also verify Merkle proof for each non-nil share. - rebuiltShares, err := codec.Decode(fraudProof.Shares) - if err != nil { - t.Errorf("could not decode fraud proof shares; got: %v", err) - } - root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index) - if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) { - // If the roots match, then the fraud proof should be for invalid erasure coding. - parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth]) +func TestValidFraudProof(t *testing.T) { + bufferSize := 64 + corruptChunk := bytes.Repeat([]byte{66}, bufferSize) + tests := []struct { + name string + // Size of each share, in bytes + shareSize int + codec Codec + }{ + {"leopard", bufferSize, NewLeoRSCodec()}, + {"infectiousGF8", bufferSize, NewRSGF8Codec()}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + name, codec, shareSize := test.name, test.codec, test.shareSize + original := createTestEds(codec, shareSize) + + var byzData *ErrByzantineData + corrupted, err := original.deepCopy(codec) + if err != nil { + t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, name) + } + corrupted.setCell(0, 0, corruptChunk) + err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) + errors.As(err, &byzData) + + // Construct the fraud proof + fraudProof := PseudoFraudProof{0, byzData.Index, byzData.Shares} + // Verify the fraud proof + // TODO in a real fraud proof, also verify Merkle proof for each non-nil share. + rebuiltShares, err := codec.Decode(fraudProof.Shares) if err != nil { - t.Errorf("could not encode fraud proof shares; %v", fraudProof) + t.Errorf("could not decode fraud proof shares; got: %v", err) } - startIndex := len(rebuiltShares) - int(corrupted.originalDataWidth) - if bytes.Equal(flattenChunks(parityShares), flattenChunks(rebuiltShares[startIndex:])) { - t.Errorf("invalid fraud proof %v", fraudProof) + root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index) + if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) { + // If the roots match, then the fraud proof should be for invalid erasure coding. + parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth]) + if err != nil { + t.Errorf("could not encode fraud proof shares; %v", fraudProof) + } + startIndex := len(rebuiltShares) - int(corrupted.originalDataWidth) + if bytes.Equal(flattenChunks(parityShares), flattenChunks(rebuiltShares[startIndex:])) { + t.Errorf("invalid fraud proof %v", fraudProof) + } } - } + }) + } +} - corrupted, err = original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - corrupted.setCell(0, 3, corruptChunk) - err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) - if !errors.As(err, &byzData) { - // due to parallelisation, the ErrByzantineData axis may be either row or col - t.Errorf("did not return a ErrByzantineData for a bad row or col; got %v", err) - } +func TestCannotRepairSquareWithBadRoots(t *testing.T) { + bufferSize := 64 + corruptChunk := bytes.Repeat([]byte{66}, bufferSize) + tests := []struct { + name string + // Size of each share, in bytes + shareSize int + codec Codec + }{ + {"leopard", bufferSize, NewLeoRSCodec()}, + {"infectiousGF8", bufferSize, NewRSGF8Codec()}, + } - corrupted, err = original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - corrupted.setCell(0, 0, corruptChunk) - corrupted.setCell(0, 1, nil) - corrupted.setCell(0, 2, nil) - corrupted.setCell(0, 3, nil) - err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) - if !errors.As(err, &byzData) || byzData.Axis != Col { - t.Errorf("did not return a ErrByzantineData for a bad column; got %v", err) - } - corrupted, err = original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - corrupted.setCell(3, 0, corruptChunk) - corrupted.setCell(0, 1, nil) - corrupted.setCell(0, 2, nil) - corrupted.setCell(0, 3, nil) - err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) - if !errors.As(err, &byzData) { - // due to parallelisation, the ErrByzantineData axis may be either row or col - t.Errorf("did not return a ErrByzantineData for a bad col or row; got %v", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + codec, shareSize := test.codec, test.shareSize + original := createTestEds(codec, shareSize) + + rowRoots := original.RowRoots() + colRoots := original.ColRoots() + + original.setCell(0, 0, corruptChunk) + err := original.Repair(rowRoots, colRoots) + if err == nil { + t.Errorf("did not return an error on trying to repair a square with bad roots") + } + }) + } +} + +func TestCorruptedEdsReturnsErrByzantineData(t *testing.T) { + bufferSize := 64 + corruptChunk := bytes.Repeat([]byte{66}, bufferSize) + + tests := []struct { + name string + // Size of each share, in bytes + shareSize int + cells [][]byte + values [][]byte + axis Axis + }{ + {"BadRow/OriginalData", bufferSize, [][]byte{{0, 0}}, [][]byte{corruptChunk}, Row}, + {"BadRow/ExtendedData", bufferSize, [][]byte{{0, 3}}, [][]byte{corruptChunk}, Row}, + {"BadColumn/OriginalData", bufferSize, [][]byte{{0, 0}, {0, 1}, {0, 2}, {0, 3}}, [][]byte{corruptChunk, nil, nil, nil}, Col}, + {"BadColumn/OriginalData", bufferSize, [][]byte{{3, 0}, {0, 1}, {0, 2}, {0, 3}}, [][]byte{corruptChunk, nil, nil, nil}, Col}, + } + + for codecName, codec := range codecs { + t.Run(codecName, func(t *testing.T) { + original := createTestEds(codec, bufferSize) + + var byzData *ErrByzantineData + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + corrupted, err := original.deepCopy(codec) + if err != nil { + t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) + } + for i := 0; i < len(test.cells); i++ { + corrupted.setCell(uint(test.cells[i][0]), uint(test.cells[i][1]), test.values[i]) + } + err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) + if !errors.As(err, &byzData) { + // due to parallelisation, the ErrByzantineData axis may be either row or col + t.Errorf("did not return a ErrByzantineData for a bad col or row; got %v", err) + } + }) + } + }) } } @@ -222,3 +273,20 @@ func BenchmarkRepair(b *testing.B) { } } } + +func createTestEds(codec Codec, bufferSize int) *ExtendedDataSquare { + ones := bytes.Repeat([]byte{1}, bufferSize) + twos := bytes.Repeat([]byte{2}, bufferSize) + threes := bytes.Repeat([]byte{3}, bufferSize) + fours := bytes.Repeat([]byte{4}, bufferSize) + + eds, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, codec, NewDefaultTree) + if err != nil { + panic(err) + } + + return eds +}