From 975e03a572650fdad748c47ae8507c5eb4aa60df Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 26 Feb 2024 14:43:15 +0100 Subject: [PATCH] fix(crypto): made MultiSignature generic This is an attempt at making the MultiSignature type be generic so that impl can be switched at the top level instead of nested. --- crypto/ecdsa/ecdsa.go | 19 +++++++----- crypto/eddsa/eddsa.go | 19 +++++++----- crypto/multi_signature.go | 29 ++++++++---------- internal/proto/hotstuffpb/convert.go | 44 +++++++++++++++------------- 4 files changed, 59 insertions(+), 52 deletions(-) diff --git a/crypto/ecdsa/ecdsa.go b/crypto/ecdsa/ecdsa.go index ff91b788..f2a60acd 100644 --- a/crypto/ecdsa/ecdsa.go +++ b/crypto/ecdsa/ecdsa.go @@ -26,6 +26,11 @@ const ( PublicKeyFileType = "ECDSA PUBLIC KEY" ) +var ( + _ hotstuff.QuorumSignature = (*crypto.MultiSignature[*Signature])(nil) + _ hotstuff.IDSet = (*crypto.MultiSignature[*Signature])(nil) +) + // Signature is an ECDSA signature type Signature struct { r, s *big.Int @@ -92,7 +97,7 @@ func (ec *ecdsaBase) Sign(message []byte) (signature hotstuff.QuorumSignature, e if err != nil { return nil, fmt.Errorf("ecdsa: sign failed: %w", err) } - return crypto.MultiSignature{ec.opts.ID(): &Signature{ + return crypto.MultiSignature[*Signature]{ec.opts.ID(): &Signature{ r: r, s: s, signer: ec.opts.ID(), @@ -105,10 +110,10 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q return nil, crypto.ErrCombineMultiple } - ts := make(crypto.MultiSignature) + ts := make(crypto.MultiSignature[*Signature]) for _, sig1 := range signatures { - if sig2, ok := sig1.(crypto.MultiSignature); ok { + if sig2, ok := sig1.(crypto.MultiSignature[*Signature]); ok { for id, s := range sig2 { if _, ok := ts[id]; ok { return nil, crypto.ErrCombineOverlap @@ -125,7 +130,7 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q // Verify verifies the given quorum signature against the message. func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) bool { - s, ok := signature.(crypto.MultiSignature) + s, ok := signature.(crypto.MultiSignature[*Signature]) if !ok { ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } @@ -141,7 +146,7 @@ func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) for _, sig := range s { go func(sig *Signature, hash hotstuff.Hash) { results <- ec.verifySingle(sig, hash) - }(sig.(*Signature), hash) + }(sig, hash) } valid := true @@ -156,7 +161,7 @@ func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) // BatchVerify verifies the given quorum signature against the batch of messages. func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool { - s, ok := signature.(crypto.MultiSignature) + s, ok := signature.(crypto.MultiSignature[*Signature]) if !ok { ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } @@ -177,7 +182,7 @@ func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[h set[hash] = struct{}{} go func(sig *Signature, hash hotstuff.Hash) { results <- ec.verifySingle(sig, hash) - }(sig.(*Signature), hash) + }(sig, hash) } valid := true diff --git a/crypto/eddsa/eddsa.go b/crypto/eddsa/eddsa.go index 5c59fb3c..bd2b4975 100644 --- a/crypto/eddsa/eddsa.go +++ b/crypto/eddsa/eddsa.go @@ -23,6 +23,11 @@ const ( PublicKeyFileType = "EDDSA PUBLIC KEY" ) +var ( + _ hotstuff.QuorumSignature = (*crypto.MultiSignature[*Signature])(nil) + _ hotstuff.IDSet = (*crypto.MultiSignature[*Signature])(nil) +) + // Signature is an ECDSA signature type Signature struct { signer hotstuff.ID @@ -74,7 +79,7 @@ func (ed *eddsaBase) privateKey() ed25519.PrivateKey { func (ed *eddsaBase) Sign(message []byte) (signature hotstuff.QuorumSignature, err error) { sign := ed25519.Sign(ed.privateKey(), message) eddsaSign := &Signature{signer: ed.opts.ID(), sign: sign} - return crypto.MultiSignature{ed.opts.ID(): eddsaSign}, nil + return crypto.MultiSignature[*Signature]{ed.opts.ID(): eddsaSign}, nil } func (ed *eddsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.QuorumSignature, error) { @@ -82,10 +87,10 @@ func (ed *eddsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q return nil, crypto.ErrCombineMultiple } - ts := make(crypto.MultiSignature) + ts := make(crypto.MultiSignature[*Signature]) for _, sig1 := range signatures { - if sig2, ok := sig1.(crypto.MultiSignature); ok { + if sig2, ok := sig1.(crypto.MultiSignature[*Signature]); ok { for id, s := range sig2 { if _, ok := ts[id]; ok { return nil, crypto.ErrCombineOverlap @@ -100,7 +105,7 @@ func (ed *eddsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q } func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) bool { - s, ok := signature.(crypto.MultiSignature) + s, ok := signature.(crypto.MultiSignature[*Signature]) if !ok { ed.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } @@ -114,7 +119,7 @@ func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) for _, sig := range s { go func(sig *Signature, msg []byte) { results <- ed.verifySingle(sig, msg) - }(sig.(*Signature), message) + }(sig, message) } valid := true @@ -128,7 +133,7 @@ func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) } func (ed *eddsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool { - s, ok := signature.(crypto.MultiSignature) + s, ok := signature.(crypto.MultiSignature[*Signature]) if !ok { ed.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } @@ -149,7 +154,7 @@ func (ed *eddsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[h set[hash] = struct{}{} go func(sig *Signature, msg []byte) { results <- ed.verifySingle(sig, msg) - }(sig.(*Signature), message) + }(sig, message) } valid := true diff --git a/crypto/multi_signature.go b/crypto/multi_signature.go index c6c92069..3793a4a9 100644 --- a/crypto/multi_signature.go +++ b/crypto/multi_signature.go @@ -14,11 +14,11 @@ type Signature interface { } // MultiSignature is a set of (partial) signatures. -type MultiSignature map[hotstuff.ID]Signature +type MultiSignature[T Signature] map[hotstuff.ID]T // RestoreMultiSignature should only be used to restore an existing threshold signature from a set of signatures. -func RestoreMultiSignature(signatures []Signature) MultiSignature { - sig := make(MultiSignature, len(signatures)) +func RestoreMultiSignature[T Signature](signatures []T) MultiSignature[T] { + sig := make(MultiSignature[T], len(signatures)) for _, s := range signatures { sig[s.Signer()] = s } @@ -26,7 +26,7 @@ func RestoreMultiSignature(signatures []Signature) MultiSignature { } // ToBytes returns the object as bytes. -func (sig MultiSignature) ToBytes() []byte { +func (sig MultiSignature[T]) ToBytes() []byte { var b []byte // sort by ID to make it deterministic order := make([]hotstuff.ID, 0, len(sig)) @@ -41,30 +41,30 @@ func (sig MultiSignature) ToBytes() []byte { } // Participants returns the IDs of replicas who participated in the threshold signature. -func (sig MultiSignature) Participants() hotstuff.IDSet { +func (sig MultiSignature[T]) Participants() hotstuff.IDSet { return sig } // Add adds an ID to the set. -func (sig MultiSignature) Add(_ hotstuff.ID) { +func (sig MultiSignature[T]) Add(_ hotstuff.ID) { panic("not implemented") } // Contains returns true if the set contains the ID. -func (sig MultiSignature) Contains(id hotstuff.ID) bool { +func (sig MultiSignature[T]) Contains(id hotstuff.ID) bool { _, ok := sig[id] return ok } // ForEach calls f for each ID in the set. -func (sig MultiSignature) ForEach(f func(hotstuff.ID)) { +func (sig MultiSignature[T]) ForEach(f func(hotstuff.ID)) { for id := range sig { f(id) } } // RangeWhile calls f for each ID in the set until f returns false. -func (sig MultiSignature) RangeWhile(f func(hotstuff.ID) bool) { +func (sig MultiSignature[T]) RangeWhile(f func(hotstuff.ID) bool) { for id := range sig { if !f(id) { break @@ -73,22 +73,17 @@ func (sig MultiSignature) RangeWhile(f func(hotstuff.ID) bool) { } // Len returns the number of entries in the set. -func (sig MultiSignature) Len() int { +func (sig MultiSignature[T]) Len() int { return len(sig) } -func (sig MultiSignature) String() string { +func (sig MultiSignature[T]) String() string { return hotstuff.IDSetToString(sig) } -func (sig MultiSignature) Type() reflect.Type { +func (sig MultiSignature[T]) Type() reflect.Type { for _, s := range sig { return reflect.TypeOf(s) } return nil } - -var ( - _ hotstuff.QuorumSignature = (*MultiSignature)(nil) - _ hotstuff.IDSet = (*MultiSignature)(nil) -) diff --git a/internal/proto/hotstuffpb/convert.go b/internal/proto/hotstuffpb/convert.go index 3256dd79..87291ebe 100644 --- a/internal/proto/hotstuffpb/convert.go +++ b/internal/proto/hotstuffpb/convert.go @@ -15,36 +15,38 @@ import ( func QuorumSignatureToProto(sig hotstuff.QuorumSignature) *QuorumSignature { signature := &QuorumSignature{} switch ms := sig.(type) { - case crypto.MultiSignature: + case crypto.MultiSignature[ecdsa.Signature]: ECDSASigs := make([]*ECDSASignature, 0, sig.Participants().Len()) - EDDSASigs := make([]*EDDSASignature, 0, sig.Participants().Len()) - for _, p := range ms { - switch s := p.(type) { - case *ecdsa.Signature: - ECDSASigs = append(ECDSASigs, &ECDSASignature{ - Signer: uint32(s.Signer()), - R: s.R().Bytes(), - S: s.S().Bytes(), - }) - case *eddsa.Signature: - EDDSASigs = append(EDDSASigs, &EDDSASignature{Signer: uint32(s.Signer()), Sig: s.ToBytes()}) - } + for _, s := range ms { + ECDSASigs = append(ECDSASigs, &ECDSASignature{ + Signer: uint32(s.Signer()), + R: s.R().Bytes(), + S: s.S().Bytes(), + }) } - if len(ECDSASigs) > 0 { - signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{ - Sigs: ECDSASigs, - }} - } else { - signature.Sig = &QuorumSignature_EDDSASigs{EDDSASigs: &EDDSAMultiSignature{ - Sigs: EDDSASigs, - }} + signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{ + Sigs: ECDSASigs, + }} + + case crypto.MultiSignature[eddsa.Signature]: + EDDSASigs := make([]*EDDSASignature, 0, sig.Participants().Len()) + for _, s := range ms { + EDDSASigs = append(EDDSASigs, &EDDSASignature{Signer: uint32(s.Signer()), Sig: s.ToBytes()}) } + signature.Sig = &QuorumSignature_EDDSASigs{EDDSASigs: &EDDSAMultiSignature{ + Sigs: EDDSASigs, + }} case *bls12.AggregateSignature: signature.Sig = &QuorumSignature_BLS12Sig{BLS12Sig: &BLS12AggregateSignature{ Sig: ms.ToBytes(), Participants: ms.Bitfield().Bytes(), }} + + default: + signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{ + Sigs: nil, + }} } return signature }