Skip to content

Commit

Permalink
fix(crypto): made MultiSignature generic
Browse files Browse the repository at this point in the history
This is an attempt at making the MultiSignature type be generic
so that impl can be switched at the top level instead of nested.
  • Loading branch information
meling committed Feb 26, 2024
1 parent 6c6e63d commit 975e03a
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 52 deletions.
19 changes: 12 additions & 7 deletions crypto/ecdsa/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ const (
PublicKeyFileType = "ECDSA PUBLIC KEY"
)

var (
_ hotstuff.QuorumSignature = (*crypto.MultiSignature[*Signature])(nil)
_ hotstuff.IDSet = (*crypto.MultiSignature[*Signature])(nil)
)

// Signature is an ECDSA signature
type Signature struct {
r, s *big.Int
Expand Down Expand Up @@ -92,7 +97,7 @@ func (ec *ecdsaBase) Sign(message []byte) (signature hotstuff.QuorumSignature, e
if err != nil {
return nil, fmt.Errorf("ecdsa: sign failed: %w", err)
}
return crypto.MultiSignature{ec.opts.ID(): &Signature{
return crypto.MultiSignature[*Signature]{ec.opts.ID(): &Signature{
r: r,
s: s,
signer: ec.opts.ID(),
Expand All @@ -105,10 +110,10 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q
return nil, crypto.ErrCombineMultiple
}

ts := make(crypto.MultiSignature)
ts := make(crypto.MultiSignature[*Signature])

for _, sig1 := range signatures {
if sig2, ok := sig1.(crypto.MultiSignature); ok {
if sig2, ok := sig1.(crypto.MultiSignature[*Signature]); ok {
for id, s := range sig2 {
if _, ok := ts[id]; ok {
return nil, crypto.ErrCombineOverlap
Expand All @@ -125,7 +130,7 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q

// Verify verifies the given quorum signature against the message.
func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) bool {
s, ok := signature.(crypto.MultiSignature)
s, ok := signature.(crypto.MultiSignature[*Signature])
if !ok {
ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s)
}
Expand All @@ -141,7 +146,7 @@ func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte)
for _, sig := range s {
go func(sig *Signature, hash hotstuff.Hash) {
results <- ec.verifySingle(sig, hash)
}(sig.(*Signature), hash)
}(sig, hash)
}

valid := true
Expand All @@ -156,7 +161,7 @@ func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte)

// BatchVerify verifies the given quorum signature against the batch of messages.
func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool {
s, ok := signature.(crypto.MultiSignature)
s, ok := signature.(crypto.MultiSignature[*Signature])
if !ok {
ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s)
}
Expand All @@ -177,7 +182,7 @@ func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[h
set[hash] = struct{}{}
go func(sig *Signature, hash hotstuff.Hash) {
results <- ec.verifySingle(sig, hash)
}(sig.(*Signature), hash)
}(sig, hash)
}

valid := true
Expand Down
19 changes: 12 additions & 7 deletions crypto/eddsa/eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ const (
PublicKeyFileType = "EDDSA PUBLIC KEY"
)

var (
_ hotstuff.QuorumSignature = (*crypto.MultiSignature[*Signature])(nil)
_ hotstuff.IDSet = (*crypto.MultiSignature[*Signature])(nil)
)

// Signature is an ECDSA signature
type Signature struct {
signer hotstuff.ID
Expand Down Expand Up @@ -74,18 +79,18 @@ func (ed *eddsaBase) privateKey() ed25519.PrivateKey {
func (ed *eddsaBase) Sign(message []byte) (signature hotstuff.QuorumSignature, err error) {
sign := ed25519.Sign(ed.privateKey(), message)
eddsaSign := &Signature{signer: ed.opts.ID(), sign: sign}
return crypto.MultiSignature{ed.opts.ID(): eddsaSign}, nil
return crypto.MultiSignature[*Signature]{ed.opts.ID(): eddsaSign}, nil
}

func (ed *eddsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.QuorumSignature, error) {
if len(signatures) < 2 {
return nil, crypto.ErrCombineMultiple
}

ts := make(crypto.MultiSignature)
ts := make(crypto.MultiSignature[*Signature])

for _, sig1 := range signatures {
if sig2, ok := sig1.(crypto.MultiSignature); ok {
if sig2, ok := sig1.(crypto.MultiSignature[*Signature]); ok {
for id, s := range sig2 {
if _, ok := ts[id]; ok {
return nil, crypto.ErrCombineOverlap
Expand All @@ -100,7 +105,7 @@ func (ed *eddsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q
}

func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) bool {
s, ok := signature.(crypto.MultiSignature)
s, ok := signature.(crypto.MultiSignature[*Signature])
if !ok {
ed.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s)
}
Expand All @@ -114,7 +119,7 @@ func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte)
for _, sig := range s {
go func(sig *Signature, msg []byte) {
results <- ed.verifySingle(sig, msg)
}(sig.(*Signature), message)
}(sig, message)
}

valid := true
Expand All @@ -128,7 +133,7 @@ func (ed *eddsaBase) Verify(signature hotstuff.QuorumSignature, message []byte)

}
func (ed *eddsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool {
s, ok := signature.(crypto.MultiSignature)
s, ok := signature.(crypto.MultiSignature[*Signature])
if !ok {
ed.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s)
}
Expand All @@ -149,7 +154,7 @@ func (ed *eddsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[h
set[hash] = struct{}{}
go func(sig *Signature, msg []byte) {
results <- ed.verifySingle(sig, msg)
}(sig.(*Signature), message)
}(sig, message)
}

valid := true
Expand Down
29 changes: 12 additions & 17 deletions crypto/multi_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ type Signature interface {
}

// MultiSignature is a set of (partial) signatures.
type MultiSignature map[hotstuff.ID]Signature
type MultiSignature[T Signature] map[hotstuff.ID]T

// RestoreMultiSignature should only be used to restore an existing threshold signature from a set of signatures.
func RestoreMultiSignature(signatures []Signature) MultiSignature {
sig := make(MultiSignature, len(signatures))
func RestoreMultiSignature[T Signature](signatures []T) MultiSignature[T] {
sig := make(MultiSignature[T], len(signatures))
for _, s := range signatures {
sig[s.Signer()] = s
}
return sig
}

// ToBytes returns the object as bytes.
func (sig MultiSignature) ToBytes() []byte {
func (sig MultiSignature[T]) ToBytes() []byte {
var b []byte
// sort by ID to make it deterministic
order := make([]hotstuff.ID, 0, len(sig))
Expand All @@ -41,30 +41,30 @@ func (sig MultiSignature) ToBytes() []byte {
}

// Participants returns the IDs of replicas who participated in the threshold signature.
func (sig MultiSignature) Participants() hotstuff.IDSet {
func (sig MultiSignature[T]) Participants() hotstuff.IDSet {
return sig
}

// Add adds an ID to the set.
func (sig MultiSignature) Add(_ hotstuff.ID) {
func (sig MultiSignature[T]) Add(_ hotstuff.ID) {
panic("not implemented")
}

// Contains returns true if the set contains the ID.
func (sig MultiSignature) Contains(id hotstuff.ID) bool {
func (sig MultiSignature[T]) Contains(id hotstuff.ID) bool {
_, ok := sig[id]
return ok
}

// ForEach calls f for each ID in the set.
func (sig MultiSignature) ForEach(f func(hotstuff.ID)) {
func (sig MultiSignature[T]) ForEach(f func(hotstuff.ID)) {
for id := range sig {
f(id)
}
}

// RangeWhile calls f for each ID in the set until f returns false.
func (sig MultiSignature) RangeWhile(f func(hotstuff.ID) bool) {
func (sig MultiSignature[T]) RangeWhile(f func(hotstuff.ID) bool) {
for id := range sig {
if !f(id) {
break
Expand All @@ -73,22 +73,17 @@ func (sig MultiSignature) RangeWhile(f func(hotstuff.ID) bool) {
}

// Len returns the number of entries in the set.
func (sig MultiSignature) Len() int {
func (sig MultiSignature[T]) Len() int {
return len(sig)
}

func (sig MultiSignature) String() string {
func (sig MultiSignature[T]) String() string {
return hotstuff.IDSetToString(sig)
}

func (sig MultiSignature) Type() reflect.Type {
func (sig MultiSignature[T]) Type() reflect.Type {

Check warning on line 84 in crypto/multi_signature.go

View workflow job for this annotation

GitHub Actions / lint

exported: exported method MultiSignature.Type should have comment or be unexported (revive)
for _, s := range sig {
return reflect.TypeOf(s)
}
return nil
}

var (
_ hotstuff.QuorumSignature = (*MultiSignature)(nil)
_ hotstuff.IDSet = (*MultiSignature)(nil)
)
44 changes: 23 additions & 21 deletions internal/proto/hotstuffpb/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,38 @@ import (
func QuorumSignatureToProto(sig hotstuff.QuorumSignature) *QuorumSignature {
signature := &QuorumSignature{}
switch ms := sig.(type) {
case crypto.MultiSignature:
case crypto.MultiSignature[ecdsa.Signature]:
ECDSASigs := make([]*ECDSASignature, 0, sig.Participants().Len())
EDDSASigs := make([]*EDDSASignature, 0, sig.Participants().Len())
for _, p := range ms {
switch s := p.(type) {
case *ecdsa.Signature:
ECDSASigs = append(ECDSASigs, &ECDSASignature{
Signer: uint32(s.Signer()),
R: s.R().Bytes(),
S: s.S().Bytes(),
})
case *eddsa.Signature:
EDDSASigs = append(EDDSASigs, &EDDSASignature{Signer: uint32(s.Signer()), Sig: s.ToBytes()})
}
for _, s := range ms {
ECDSASigs = append(ECDSASigs, &ECDSASignature{
Signer: uint32(s.Signer()),
R: s.R().Bytes(),
S: s.S().Bytes(),
})
}
if len(ECDSASigs) > 0 {
signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{
Sigs: ECDSASigs,
}}
} else {
signature.Sig = &QuorumSignature_EDDSASigs{EDDSASigs: &EDDSAMultiSignature{
Sigs: EDDSASigs,
}}
signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{
Sigs: ECDSASigs,
}}

case crypto.MultiSignature[eddsa.Signature]:
EDDSASigs := make([]*EDDSASignature, 0, sig.Participants().Len())
for _, s := range ms {
EDDSASigs = append(EDDSASigs, &EDDSASignature{Signer: uint32(s.Signer()), Sig: s.ToBytes()})
}
signature.Sig = &QuorumSignature_EDDSASigs{EDDSASigs: &EDDSAMultiSignature{
Sigs: EDDSASigs,
}}

case *bls12.AggregateSignature:
signature.Sig = &QuorumSignature_BLS12Sig{BLS12Sig: &BLS12AggregateSignature{
Sig: ms.ToBytes(),
Participants: ms.Bitfield().Bytes(),
}}

default:
signature.Sig = &QuorumSignature_ECDSASigs{ECDSASigs: &ECDSAMultiSignature{
Sigs: nil,
}}
}
return signature
}
Expand Down

0 comments on commit 975e03a

Please sign in to comment.