diff --git a/block.go b/block.go index 062c9beb..e731e087 100644 --- a/block.go +++ b/block.go @@ -4,6 +4,9 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "io" + + "github.com/relab/hotstuff/util" ) // Block contains a propsed "command", metadata for the protocol, and a link to the "parent" block. @@ -26,8 +29,13 @@ func NewBlock(parent Hash, cert QuorumCert, cmd Command, view View, proposer ID) view: view, proposer: proposer, } + hasher := sha256.New() + _, err := b.WriteTo(hasher) + if err != nil { + panic("unexpected error: " + err.Error()) + } // cache the hash immediately because it is too racy to do it in Hash() - b.hash = sha256.Sum256(b.ToBytes()) + hasher.Sum(b.hash[:0]) return b } @@ -72,16 +80,19 @@ func (b *Block) View() View { return b.view } -// ToBytes returns the raw byte form of the Block, to be used for hashing, etc. -func (b *Block) ToBytes() []byte { - buf := b.parent[:] +// WriteTo writes the block data to the writer. +func (b *Block) WriteTo(writer io.Writer) (n int64, err error) { var proposerBuf [4]byte binary.LittleEndian.PutUint32(proposerBuf[:], uint32(b.proposer)) - buf = append(buf, proposerBuf[:]...) + var viewBuf [8]byte binary.LittleEndian.PutUint64(viewBuf[:], uint64(b.view)) - buf = append(buf, viewBuf[:]...) - buf = append(buf, []byte(b.cmd)...) - buf = append(buf, b.cert.ToBytes()...) - return buf + + return util.WriteAllTo( + writer, + b.parent[:], + proposerBuf[:], + b.Command(), + b.cert, + ) } diff --git a/crypto/crypto.go b/crypto/crypto.go index 0a7311c8..de4403fe 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -2,8 +2,11 @@ package crypto import ( + "bytes" + "github.com/relab/hotstuff" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/util/gpool" ) type crypto struct { @@ -34,9 +37,22 @@ func (c *crypto) InitModule(mods *modules.Core) { } } +var bufferPool gpool.Pool[bytes.Buffer] + // CreatePartialCert signs a single block and returns the partial certificate. func (c crypto) CreatePartialCert(block *hotstuff.Block) (cert hotstuff.PartialCert, err error) { - sig, err := c.Sign(block.ToBytes()) + buf := bufferPool.Get() + _, err = block.WriteTo(&buf) + if err != nil { + return cert, err + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + sig, err := c.Sign(buf.Bytes()) if err != nil { return hotstuff.PartialCert{}, err } @@ -102,7 +118,19 @@ func (c crypto) VerifyPartialCert(cert hotstuff.PartialCert) bool { if !ok { return false } - return c.Verify(cert.Signature(), block.ToBytes()) + + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return false + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + return c.Verify(cert.Signature(), buf.Bytes()) } // VerifyQuorumCert verifies a quorum certificate. @@ -118,7 +146,19 @@ func (c crypto) VerifyQuorumCert(qc hotstuff.QuorumCert) bool { if !ok { return false } - return c.Verify(qc.Signature(), block.ToBytes()) + + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return false + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + return c.Verify(qc.Signature(), buf.Bytes()) } // VerifyTimeoutCert verifies a timeout certificate. diff --git a/crypto/ecdsa/ecdsa.go b/crypto/ecdsa/ecdsa.go index 8930bd05..961e0b9b 100644 --- a/crypto/ecdsa/ecdsa.go +++ b/crypto/ecdsa/ecdsa.go @@ -6,12 +6,14 @@ import ( "crypto/rand" "crypto/sha256" "fmt" + "io" "math/big" "github.com/relab/hotstuff" "github.com/relab/hotstuff/crypto" "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/util" "golang.org/x/exp/slices" ) @@ -61,6 +63,11 @@ func (sig Signature) ToBytes() []byte { return b } +// WriteTo writes the signature to the writer. +func (sig Signature) WriteTo(writer io.Writer) (n int64, err error) { + return util.WriteAllTo(writer, sig.r, sig.s) +} + // MultiSignature is a set of (partial) signatures. type MultiSignature map[hotstuff.ID]*Signature @@ -88,6 +95,25 @@ func (sig MultiSignature) ToBytes() []byte { return b } +// WriteTo writes the multi signature to the writer. +func (sig MultiSignature) WriteTo(writer io.Writer) (n int64, err error) { + // sort by ID to make it deterministic + order := make([]hotstuff.ID, 0, len(sig)) + for _, signature := range sig { + order = append(order, signature.signer) + } + slices.Sort(order) + var nn int64 + for _, id := range order { + nn, err = sig[id].WriteTo(writer) + n += nn + if err != nil { + return n, err + } + } + return n, nil +} + // Participants returns the IDs of replicas who participated in the threshold signature. func (sig MultiSignature) Participants() hotstuff.IDSet { return sig diff --git a/handel/session.go b/handel/session.go index 668dc524..87a22b24 100644 --- a/handel/session.go +++ b/handel/session.go @@ -1,6 +1,7 @@ package handel import ( + "bytes" "context" "encoding/binary" "math/rand" @@ -14,6 +15,7 @@ import ( "github.com/relab/hotstuff/internal/proto/handelpb" "github.com/relab/hotstuff/internal/proto/hotstuffpb" "github.com/relab/hotstuff/synchronizer" + "github.com/relab/hotstuff/util/gpool" ) const ( @@ -624,16 +626,29 @@ func (s *session) improveSignature(contribution contribution) hotstuff.QuorumSig return signature } +var bufferPool gpool.Pool[bytes.Buffer] + func (s *session) verifyContribution(c contribution, sig hotstuff.QuorumSignature, verifyIndiv bool) { block, ok := s.h.blockChain.Get(s.hash) if !ok { return } + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + s.h.logger.Debugf("verifying: %v (= %d)", sig.Participants(), sig.Participants().Len()) aggVerified := false - if s.h.crypto.Verify(sig, block.ToBytes()) { + if s.h.crypto.Verify(sig, buf.Bytes()) { aggVerified = true } else { s.h.logger.Debug("failed to verify aggregate signature") @@ -642,7 +657,7 @@ func (s *session) verifyContribution(c contribution, sig hotstuff.QuorumSignatur indivVerified := false // If the contribution is individual, we want to verify it separately if verifyIndiv { - if s.h.crypto.Verify(c.individual, block.ToBytes()) { + if s.h.crypto.Verify(c.individual, buf.Bytes()) { indivVerified = true } else { s.h.logger.Debug("failed to verify individual signature") diff --git a/types.go b/types.go index 51cbfe8c..3e3551b2 100644 --- a/types.go +++ b/types.go @@ -9,6 +9,8 @@ import ( "io" "strconv" "strings" + + "github.com/relab/hotstuff/util" ) // IDSet implements a set of replica IDs. It is used to show which replicas participated in some event. @@ -101,7 +103,7 @@ func (h Hash) String() string { // Command is a client request to be executed by the consensus protocol. // // The string type is used because it is immutable and can hold arbitrary bytes of any length. -type Command string +type Command = string // ToBytes is an object that can be converted into bytes for the purposes of hashing, etc. type ToBytes interface { @@ -256,6 +258,16 @@ func NewQuorumCert(signature QuorumSignature, view View, hash Hash) QuorumCert { return QuorumCert{signature, view, hash} } +// WriteTo writes the quorum certificate to the writer. +func (qc QuorumCert) WriteTo(writer io.Writer) (n int64, err error) { + return util.WriteAllTo( + writer, + qc.view.ToBytes(), + qc.hash[:], + qc.signature, + ) +} + // ToBytes returns a byte representation of the quorum certificate. func (qc QuorumCert) ToBytes() []byte { b := qc.view.ToBytes() diff --git a/util/io.go b/util/io.go new file mode 100644 index 00000000..8a4cee57 --- /dev/null +++ b/util/io.go @@ -0,0 +1,58 @@ +package util + +import ( + "fmt" + "io" + "reflect" + "unsafe" +) + +type toBytes interface { + ToBytes() []byte +} + +type bytes interface { + Bytes() []byte +} + +// WriteAllTo writes all the data to the writer. +func WriteAllTo(writer io.Writer, data ...any) (n int64, err error) { + for _, d := range data { + var ( + nn int64 + nnn int + ) + switch d := d.(type) { + case io.WriterTo: + nn, err = d.WriteTo(writer) + case string: + nnn, err = writer.Write(unsafeStringToBytes(d)) + case []byte: + nnn, err = writer.Write(d) + case toBytes: + nnn, err = writer.Write(d.ToBytes()) + case bytes: + nnn, err = writer.Write(d.Bytes()) + case nil: + default: + panic(fmt.Sprintf("cannot write %T", d)) + } + nn += int64(nnn) + n += int64(nn) + if err != nil { + return n, err + } + } + return n, nil +} + +func unsafeStringToBytes(s string) []byte { + if s == "" { + return []byte{} + } + const max = 0x7fff0000 + if len(s) > max { + panic("string too long") + } + return (*[max]byte)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data))[:len(s):len(s)] +}