diff --git a/backend/backend_test.go b/backend/backend_test.go index 0f0c5eb6..bbab0367 100644 --- a/backend/backend_test.go +++ b/backend/backend_test.go @@ -27,7 +27,7 @@ func TestConnect(t *testing.T) { const n = 4 ctrl := gomock.NewController(t) td := setup(t, ctrl, n) - builder := modules.NewConsensusBuilder(1, td.keys[0]) + builder := modules.NewBuilder(1, td.keys[0]) testutil.TestModules(t, ctrl, 1, td.keys[0], &builder) teardown := createServers(t, td, ctrl) defer teardown() @@ -35,7 +35,7 @@ func TestConnect(t *testing.T) { cfg := NewConfig(td.creds, gorums.WithDialTimeout(time.Second)) - builder.Register(cfg) + builder.Add(cfg) builder.Build() err := cfg.Connect(td.replicas) @@ -58,7 +58,7 @@ func testBase(t *testing.T, typ any, send func(modules.Configuration), handle ev defer serverTeardown() cfg := NewConfig(td.creds, gorums.WithDialTimeout(time.Second)) - td.builders[0].Register(cfg) + td.builders[0].Add(cfg) hl := td.builders.Build() err := cfg.Connect(td.replicas) @@ -69,8 +69,14 @@ func testBase(t *testing.T, typ any, send func(modules.Configuration), handle ev ctx, cancel := context.WithCancel(context.Background()) for _, hs := range hl[1:] { - hs.EventLoop().RegisterHandler(typ, handle) - go hs.Run(ctx) + var ( + eventLoop *eventloop.EventLoop + synchronizer modules.Synchronizer + ) + hs.GetAll(&eventLoop, &synchronizer) + eventLoop.RegisterHandler(typ, handle) + synchronizer.Start(ctx) + go eventLoop.Run(ctx) } send(cfg) cancel() @@ -219,7 +225,7 @@ func createServers(t *testing.T, td testData, ctrl *gomock.Controller) (teardown for i := range servers { servers[i] = NewServer(gorums.WithGRPCServerOptions(grpc.Creds(td.creds))) servers[i].StartOnListener(td.listeners[i]) - td.builders[i].Register(servers[i]) + td.builders[i].Add(servers[i]) } return func() { for _, srv := range servers { diff --git a/backend/config.go b/backend/config.go index d09c2df5..f79cd65b 100644 --- a/backend/config.go +++ b/backend/config.go @@ -7,7 +7,10 @@ import ( "fmt" "strings" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/synchronizer" "github.com/relab/gorums" "github.com/relab/hotstuff" @@ -21,12 +24,11 @@ import ( // Replica provides methods used by hotstuff to send messages to replicas. type Replica struct { - node *hotstuffpb.Node - id hotstuff.ID - pubKey hotstuff.PublicKey - voteCancel context.CancelFunc - newViewCancel context.CancelFunc - md map[string]string + eventLoop *eventloop.EventLoop + node *hotstuffpb.Node + id hotstuff.ID + pubKey hotstuff.PublicKey + md map[string]string } // ID returns the replica's ID. @@ -44,11 +46,10 @@ func (r *Replica) Vote(cert hotstuff.PartialCert) { if r.node == nil { return } - var ctx context.Context - r.voteCancel() - ctx, r.voteCancel = context.WithCancel(context.Background()) + ctx, cancel := synchronizer.TimeoutContext(r.eventLoop.Context(), r.eventLoop) + defer cancel() pCert := hotstuffpb.PartialCertToProto(cert) - r.node.Vote(ctx, pCert, gorums.WithNoSendWaiting()) + r.node.Vote(ctx, pCert) } // NewView sends the quorum certificate to the other replica. @@ -56,10 +57,9 @@ func (r *Replica) NewView(msg hotstuff.SyncInfo) { if r.node == nil { return } - var ctx context.Context - r.newViewCancel() - ctx, r.newViewCancel = context.WithCancel(context.Background()) - r.node.NewView(ctx, hotstuffpb.SyncInfoToProto(msg), gorums.WithNoSendWaiting()) + ctx, cancel := synchronizer.TimeoutContext(r.eventLoop.Context(), r.eventLoop) + defer cancel() + r.node.NewView(ctx, hotstuffpb.SyncInfoToProto(msg)) } // Metadata returns the gRPC metadata from this replica's connection. @@ -78,20 +78,26 @@ type Config struct { } type subConfig struct { - mods *modules.ConsensusCore + eventLoop *eventloop.EventLoop + logger logging.Logger + opts *modules.Options + cfg *hotstuffpb.Configuration replicas map[hotstuff.ID]modules.Replica } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (cfg *Config) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - cfg.mods = mods +// InitModule initializes the configuration. +func (cfg *Config) InitModule(mods *modules.Core) { + mods.GetAll( + &cfg.eventLoop, + &cfg.logger, + &cfg.subConfig.opts, + ) // We delay processing `replicaConnected` events until after the configurations `connected` event has occurred. - cfg.mods.EventLoop().RegisterHandler(replicaConnected{}, func(event any) { + cfg.eventLoop.RegisterHandler(replicaConnected{}, func(event any) { if !cfg.connected { - cfg.mods.EventLoop().DelayUntil(connected{}, event) + cfg.eventLoop.DelayUntil(ConnectedEvent{}, event) return } cfg.replicaConnected(event.(replicaConnected)) @@ -129,19 +135,19 @@ func (cfg *Config) replicaConnected(c replicaConnected) { id, err := GetPeerIDFromContext(c.ctx, cfg) if err != nil { - cfg.mods.Logger().Warnf("Failed to get id for %v: %v", info.Addr, err) + cfg.logger.Warnf("Failed to get id for %v: %v", info.Addr, err) return } replica, ok := cfg.replicas[id] if !ok { - cfg.mods.Logger().Warnf("Replica with id %d was not found", id) + cfg.logger.Warnf("Replica with id %d was not found", id) return } replica.(*Replica).md = readMetadata(md) - cfg.mods.Logger().Debugf("Replica %d connected from address %v", id, info.Addr) + cfg.logger.Debugf("Replica %d connected from address %v", id, info.Addr) } const keyPrefix = "hotstuff-" @@ -181,10 +187,10 @@ func (cfg *Config) Connect(replicas []ReplicaInfo) (err error) { opts := cfg.opts cfg.opts = nil // options are not needed beyond this point, so we delete them. - md := mapToMetadata(cfg.mods.Options().ConnectionMetadata()) + md := mapToMetadata(cfg.subConfig.opts.ConnectionMetadata()) // embed own ID to allow other replicas to identify messages from this replica - md.Set("id", fmt.Sprintf("%d", cfg.mods.ID())) + md.Set("id", fmt.Sprintf("%d", cfg.subConfig.opts.ID())) opts = append(opts, gorums.WithMetadata(md)) @@ -195,14 +201,13 @@ func (cfg *Config) Connect(replicas []ReplicaInfo) (err error) { for _, replica := range replicas { // also initialize Replica structures cfg.replicas[replica.ID] = &Replica{ - id: replica.ID, - pubKey: replica.PubKey, - newViewCancel: func() {}, - voteCancel: func() {}, - md: make(map[string]string), + eventLoop: cfg.eventLoop, + id: replica.ID, + pubKey: replica.PubKey, + md: make(map[string]string), } // we do not want to connect to ourself - if replica.ID != cfg.mods.ID() { + if replica.ID != cfg.subConfig.opts.ID() { idMapping[replica.Address] = uint32(replica.ID) } } @@ -225,7 +230,7 @@ func (cfg *Config) Connect(replicas []ReplicaInfo) (err error) { cfg.connected = true // this event is sent so that any delayed `replicaConnected` events can be processed. - cfg.mods.EventLoop().AddEvent(connected{}) + cfg.eventLoop.AddEvent(ConnectedEvent{}) return nil } @@ -254,9 +259,11 @@ func (cfg *Config) SubConfig(ids []hotstuff.ID) (sub modules.Configuration, err return nil, err } return &subConfig{ - mods: cfg.mods, - cfg: newCfg, - replicas: replicas, + eventLoop: cfg.eventLoop, + logger: cfg.logger, + opts: cfg.subConfig.opts, + cfg: newCfg, + replicas: replicas, }, nil } @@ -279,10 +286,11 @@ func (cfg *subConfig) Propose(proposal hotstuff.ProposeMsg) { if cfg.cfg == nil { return } + ctx, cancel := synchronizer.TimeoutContext(cfg.eventLoop.Context(), cfg.eventLoop) + defer cancel() cfg.cfg.Propose( - cfg.mods.Synchronizer().ViewContext(), + ctx, hotstuffpb.ProposalToProto(proposal), - gorums.WithNoSendWaiting(), ) } @@ -291,10 +299,14 @@ func (cfg *subConfig) Timeout(msg hotstuff.TimeoutMsg) { if cfg.cfg == nil { return } + + // will wait until the second timeout before cancelling + ctx, cancel := synchronizer.TimeoutContext(cfg.eventLoop.Context(), cfg.eventLoop) + defer cancel() + cfg.cfg.Timeout( - cfg.mods.Synchronizer().ViewContext(), + ctx, hotstuffpb.TimeoutMsgToProto(msg), - gorums.WithNoSendWaiting(), ) } @@ -305,7 +317,7 @@ func (cfg *subConfig) Fetch(ctx context.Context, hash hotstuff.Hash) (*hotstuff. qcErr, ok := err.(gorums.QuorumCallError) // filter out context errors if !ok || (qcErr.Reason != context.Canceled.Error() && qcErr.Reason != context.DeadlineExceeded.Error()) { - cfg.mods.Logger().Infof("Failed to fetch block: %v", err) + cfg.logger.Infof("Failed to fetch block: %v", err) } return nil, false } @@ -335,4 +347,5 @@ func (q qspec) FetchQF(in *hotstuffpb.BlockHash, replies map[uint32]*hotstuffpb. return nil, false } -type connected struct{} +// ConnectedEvent is sent when the configuration has connected to the other replicas. +type ConnectedEvent struct{} diff --git a/backend/server.go b/backend/server.go index ca2ba5cc..7d48cce3 100644 --- a/backend/server.go +++ b/backend/server.go @@ -3,10 +3,13 @@ package backend import ( "context" "fmt" - "github.com/relab/hotstuff/modules" "net" "strconv" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" + "github.com/relab/hotstuff/modules" + "github.com/relab/gorums" "github.com/relab/hotstuff" "github.com/relab/hotstuff/internal/proto/hotstuffpb" @@ -20,14 +23,22 @@ import ( // Server is the Server-side of the gorums backend. // It is responsible for calling handler methods on the consensus instance. type Server struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + configuration modules.Configuration + eventLoop *eventloop.EventLoop + logger logging.Logger + gorumsSrv *gorums.Server } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (srv *Server) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - srv.mods = mods +// InitModule initializes the Server. +func (srv *Server) InitModule(mods *modules.Core) { + mods.GetAll( + &srv.eventLoop, + &srv.configuration, + &srv.blockChain, + &srv.logger, + ) } // NewServer creates a new Server. @@ -35,7 +46,7 @@ func NewServer(opts ...gorums.ServerOption) *Server { srv := &Server{} opts = append(opts, gorums.WithConnectCallback(func(ctx context.Context) { - srv.mods.EventLoop().AddEvent(replicaConnected{ctx}) + srv.eventLoop.AddEvent(replicaConnected{ctx}) })) srv.gorumsSrv = gorums.NewServer(opts...) @@ -64,7 +75,7 @@ func (srv *Server) StartOnListener(listener net.Listener) { go func() { err := srv.gorumsSrv.Serve(listener) if err != nil { - srv.mods.Logger().Errorf("An error occurred while serving: %v", err) + srv.logger.Errorf("An error occurred while serving: %v", err) } }() } @@ -123,9 +134,9 @@ type serviceImpl struct { // Propose handles a replica's response to the Propose QC from the leader. func (impl *serviceImpl) Propose(ctx gorums.ServerCtx, proposal *hotstuffpb.Proposal) { - id, err := GetPeerIDFromContext(ctx, impl.srv.mods.Configuration()) + id, err := GetPeerIDFromContext(ctx, impl.srv.configuration) if err != nil { - impl.srv.mods.Logger().Infof("Failed to get client ID: %v", err) + impl.srv.logger.Infof("Failed to get client ID: %v", err) return } @@ -133,18 +144,18 @@ func (impl *serviceImpl) Propose(ctx gorums.ServerCtx, proposal *hotstuffpb.Prop proposeMsg := hotstuffpb.ProposalFromProto(proposal) proposeMsg.ID = id - impl.srv.mods.EventLoop().AddEvent(proposeMsg) + impl.srv.eventLoop.AddEvent(proposeMsg) } // Vote handles an incoming vote message. func (impl *serviceImpl) Vote(ctx gorums.ServerCtx, cert *hotstuffpb.PartialCert) { - id, err := GetPeerIDFromContext(ctx, impl.srv.mods.Configuration()) + id, err := GetPeerIDFromContext(ctx, impl.srv.configuration) if err != nil { - impl.srv.mods.Logger().Infof("Failed to get client ID: %v", err) + impl.srv.logger.Infof("Failed to get client ID: %v", err) return } - impl.srv.mods.EventLoop().AddEvent(hotstuff.VoteMsg{ + impl.srv.eventLoop.AddEvent(hotstuff.VoteMsg{ ID: id, PartialCert: hotstuffpb.PartialCertFromProto(cert), }) @@ -152,13 +163,13 @@ func (impl *serviceImpl) Vote(ctx gorums.ServerCtx, cert *hotstuffpb.PartialCert // NewView handles the leader's response to receiving a NewView rpc from a replica. func (impl *serviceImpl) NewView(ctx gorums.ServerCtx, msg *hotstuffpb.SyncInfo) { - id, err := GetPeerIDFromContext(ctx, impl.srv.mods.Configuration()) + id, err := GetPeerIDFromContext(ctx, impl.srv.configuration) if err != nil { - impl.srv.mods.Logger().Infof("Failed to get client ID: %v", err) + impl.srv.logger.Infof("Failed to get client ID: %v", err) return } - impl.srv.mods.EventLoop().AddEvent(hotstuff.NewViewMsg{ + impl.srv.eventLoop.AddEvent(hotstuff.NewViewMsg{ ID: id, SyncInfo: hotstuffpb.SyncInfoFromProto(msg), }) @@ -169,12 +180,12 @@ func (impl *serviceImpl) Fetch(ctx gorums.ServerCtx, pb *hotstuffpb.BlockHash) ( var hash hotstuff.Hash copy(hash[:], pb.GetHash()) - block, ok := impl.srv.mods.BlockChain().LocalGet(hash) + block, ok := impl.srv.blockChain.LocalGet(hash) if !ok { return nil, status.Errorf(codes.NotFound, "requested block was not found") } - impl.srv.mods.Logger().Debugf("OnFetch: %.8s", hash) + impl.srv.logger.Debugf("OnFetch: %.8s", hash) return hotstuffpb.BlockToProto(block), nil } @@ -183,11 +194,11 @@ func (impl *serviceImpl) Fetch(ctx gorums.ServerCtx, pb *hotstuffpb.BlockHash) ( func (impl *serviceImpl) Timeout(ctx gorums.ServerCtx, msg *hotstuffpb.TimeoutMsg) { var err error timeoutMsg := hotstuffpb.TimeoutMsgFromProto(msg) - timeoutMsg.ID, err = GetPeerIDFromContext(ctx, impl.srv.mods.Configuration()) + timeoutMsg.ID, err = GetPeerIDFromContext(ctx, impl.srv.configuration) if err != nil { - impl.srv.mods.Logger().Infof("Could not get ID of replica: %v", err) + impl.srv.logger.Infof("Could not get ID of replica: %v", err) } - impl.srv.mods.EventLoop().AddEvent(timeoutMsg) + impl.srv.eventLoop.AddEvent(timeoutMsg) } type replicaConnected struct { diff --git a/block.go b/block.go index 062c9beb..e731e087 100644 --- a/block.go +++ b/block.go @@ -4,6 +4,9 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "io" + + "github.com/relab/hotstuff/util" ) // Block contains a propsed "command", metadata for the protocol, and a link to the "parent" block. @@ -26,8 +29,13 @@ func NewBlock(parent Hash, cert QuorumCert, cmd Command, view View, proposer ID) view: view, proposer: proposer, } + hasher := sha256.New() + _, err := b.WriteTo(hasher) + if err != nil { + panic("unexpected error: " + err.Error()) + } // cache the hash immediately because it is too racy to do it in Hash() - b.hash = sha256.Sum256(b.ToBytes()) + hasher.Sum(b.hash[:0]) return b } @@ -72,16 +80,19 @@ func (b *Block) View() View { return b.view } -// ToBytes returns the raw byte form of the Block, to be used for hashing, etc. -func (b *Block) ToBytes() []byte { - buf := b.parent[:] +// WriteTo writes the block data to the writer. +func (b *Block) WriteTo(writer io.Writer) (n int64, err error) { var proposerBuf [4]byte binary.LittleEndian.PutUint32(proposerBuf[:], uint32(b.proposer)) - buf = append(buf, proposerBuf[:]...) + var viewBuf [8]byte binary.LittleEndian.PutUint64(viewBuf[:], uint64(b.view)) - buf = append(buf, viewBuf[:]...) - buf = append(buf, []byte(b.cmd)...) - buf = append(buf, b.cert.ToBytes()...) - return buf + + return util.WriteAllTo( + writer, + b.parent[:], + proposerBuf[:], + b.Command(), + b.cert, + ) } diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 3ec41ba8..cd9c3b2a 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -3,15 +3,23 @@ package blockchain import ( "context" + "sync" + "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" - "sync" + "github.com/relab/hotstuff/synchronizer" ) // blockChain stores a limited amount of blocks in a map. // blocks are evicted in LRU order. type blockChain struct { - mods *modules.ConsensusCore + configuration modules.Configuration + consensus modules.Consensus + eventLoop *eventloop.EventLoop + logger logging.Logger + mut sync.Mutex pruneHeight hotstuff.View blocks map[hotstuff.Hash]*hotstuff.Block @@ -19,10 +27,13 @@ type blockChain struct { pendingFetch map[hotstuff.Hash]context.CancelFunc // allows a pending fetch operation to be cancelled } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (chain *blockChain) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - chain.mods = mods +func (chain *blockChain) InitModule(mods *modules.Core) { + mods.GetAll( + &chain.configuration, + &chain.consensus, + &chain.eventLoop, + &chain.logger, + ) } // New creates a new blockChain with a maximum size. @@ -79,12 +90,12 @@ func (chain *blockChain) Get(hash hotstuff.Hash) (block *hotstuff.Block, ok bool goto done } - ctx, cancel = context.WithCancel(chain.mods.Synchronizer().ViewContext()) + ctx, cancel = synchronizer.TimeoutContext(chain.eventLoop.Context(), chain.eventLoop) chain.pendingFetch[hash] = cancel chain.mut.Unlock() - chain.mods.Logger().Debugf("Attempting to fetch block: %.8s", hash) - block, ok = chain.mods.Configuration().Fetch(ctx, hash) + chain.logger.Debugf("Attempting to fetch block: %.8s", hash) + block, ok = chain.configuration.Fetch(ctx, hash) chain.mut.Lock() delete(chain.pendingFetch, hash) @@ -94,13 +105,13 @@ func (chain *blockChain) Get(hash hotstuff.Hash) (block *hotstuff.Block, ok bool goto done } - chain.mods.Logger().Debugf("Successfully fetched block: %.8s", hash) + chain.logger.Debugf("Successfully fetched block: %.8s", hash) chain.blocks[hash] = block chain.blockAtHeight[block.View()] = block done: - defer chain.mut.Unlock() + chain.mut.Unlock() if !ok { return nil, false @@ -123,7 +134,7 @@ func (chain *blockChain) PruneToHeight(height hotstuff.View) (forkedBlocks []*ho chain.mut.Lock() defer chain.mut.Unlock() - committedHeight := chain.mods.Consensus().CommittedBlock().View() + committedHeight := chain.consensus.CommittedBlock().View() committedViews := make(map[hotstuff.View]bool) committedViews[committedHeight] = true for h := committedHeight; h >= chain.pruneHeight; { @@ -143,7 +154,7 @@ func (chain *blockChain) PruneToHeight(height hotstuff.View) (forkedBlocks []*ho if !committedViews[h] { block, ok := chain.blockAtHeight[h] if ok { - chain.mods.Logger().Debugf("PruneToHeight: found forked block: %v", block) + chain.logger.Debugf("PruneToHeight: found forked block: %v", block) forkedBlocks = append(forkedBlocks, block) } } diff --git a/client/client.go b/client/client.go index 6b909be4..1bcdf0e6 100644 --- a/client/client.go +++ b/client/client.go @@ -15,7 +15,9 @@ import ( "github.com/relab/gorums" "github.com/relab/hotstuff" "github.com/relab/hotstuff/backend" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/internal/proto/clientpb" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "golang.org/x/time/rate" "google.golang.org/grpc" @@ -58,8 +60,11 @@ type Config struct { // Client is a hotstuff client. type Client struct { + eventLoop *eventloop.EventLoop + logger logging.Logger + opts *modules.Options + mut sync.Mutex - mods *modules.Core mgr *clientpb.Manager gorumsConfig *clientpb.Configuration payloadSize uint32 @@ -74,12 +79,18 @@ type Client struct { timeout time.Duration } -// New returns a new Client. -func New(conf Config, builder modules.CoreBuilder) (client *Client) { - mods := builder.Build() +// InitModule initializes the client. +func (c *Client) InitModule(mods *modules.Core) { + mods.GetAll( + &c.eventLoop, + &c.logger, + &c.opts, + ) +} +// New returns a new Client. +func New(conf Config, builder modules.Builder) (client *Client) { client = &Client{ - mods: mods, pendingCmds: make(chan pendingCmd, conf.MaxConcurrent), highestCommitted: 1, done: make(chan struct{}), @@ -91,6 +102,10 @@ func New(conf Config, builder modules.CoreBuilder) (client *Client) { timeout: conf.Timeout, } + builder.Add(client) + + builder.Build() + grpcOpts := []grpc.DialOption{grpc.WithBlock()} var creds credentials.TransportCredentials @@ -133,10 +148,10 @@ func (c *Client) Run(ctx context.Context) { eventLoopDone := make(chan struct{}) go func() { - c.mods.EventLoop().Run(ctx) + c.eventLoop.Run(ctx) close(eventLoopDone) }() - c.mods.Logger().Info("Starting to send commands") + c.logger.Info("Starting to send commands") commandStatsChan := make(chan stats) // start the command handler @@ -147,12 +162,12 @@ func (c *Client) Run(ctx context.Context) { err := c.sendCommands(ctx) if err != nil && !errors.Is(err, io.EOF) { - c.mods.Logger().Panicf("Failed to send commands: %v", err) + c.logger.Panicf("Failed to send commands: %v", err) } c.close() commandStats := <-commandStatsChan - c.mods.Logger().Infof( + c.logger.Infof( "Done sending commands (executed: %d, failed: %d, timeouts: %d)", commandStats.executed, commandStats.failed, commandStats.timeout, ) @@ -177,7 +192,7 @@ func (c *Client) close() { c.mgr.Close() err := c.reader.Close() if err != nil { - c.mods.Logger().Warn("Failed to close reader: ", err) + c.logger.Warn("Failed to close reader: ", err) } } @@ -222,11 +237,11 @@ loop: return err } else if err == io.EOF && n == 0 && lastCommand > num { lastCommand = num - c.mods.Logger().Info("Reached end of file. Sending empty commands until last command is executed...") + c.logger.Info("Reached end of file. Sending empty commands until last command is executed...") } cmd := &clientpb.Command{ - ClientID: uint32(c.mods.ID()), + ClientID: uint32(c.opts.ID()), SequenceNumber: num, Data: data[:n], } @@ -243,7 +258,7 @@ loop: } if num%100 == 0 { - c.mods.Logger().Infof("%d commands sent", num) + c.logger.Infof("%d commands sent", num) } } @@ -271,10 +286,10 @@ func (c *Client) handleCommands(ctx context.Context) (executed, failed, timeout if err != nil { qcError, ok := err.(gorums.QuorumCallError) if ok && qcError.Reason == context.DeadlineExceeded.Error() { - c.mods.Logger().Debug("Command timed out.") + c.logger.Debug("Command timed out.") timeout++ } else if !ok || qcError.Reason != context.Canceled.Error() { - c.mods.Logger().Debugf("Did not get enough replies for command: %v\n", err) + c.logger.Debugf("Did not get enough replies for command: %v\n", err) failed++ } } else { @@ -287,7 +302,7 @@ func (c *Client) handleCommands(ctx context.Context) (executed, failed, timeout c.mut.Unlock() duration := time.Since(cmd.sendTime) - c.mods.EventLoop().AddEvent(LatencyMeasurementEvent{Latency: duration}) + c.eventLoop.AddEvent(LatencyMeasurementEvent{Latency: duration}) } } diff --git a/consensus/byzantine/byzantine.go b/consensus/byzantine/byzantine.go index b8ebb330..0d3bafe1 100644 --- a/consensus/byzantine/byzantine.go +++ b/consensus/byzantine/byzantine.go @@ -22,11 +22,9 @@ type silence struct { consensus.Rules } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (s *silence) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - if mod, ok := s.Rules.(modules.ConsensusModule); ok { - mod.InitModule(mods, opts) +func (s *silence) InitModule(mods *modules.Core) { + if mod, ok := s.Rules.(modules.Module); ok { + mod.InitModule(mods) } } @@ -45,40 +43,45 @@ func NewSilence(c consensus.Rules) consensus.Rules { } type fork struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + synchronizer modules.Synchronizer + opts *modules.Options consensus.Rules } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (f *fork) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - f.mods = mods - if mod, ok := f.Rules.(modules.ConsensusModule); ok { - mod.InitModule(mods, opts) +func (f *fork) InitModule(mods *modules.Core) { + mods.GetAll( + &f.blockChain, + &f.synchronizer, + &f.opts, + ) + + if mod, ok := f.Rules.(modules.Module); ok { + mod.InitModule(mods) } } func (f *fork) ProposeRule(cert hotstuff.SyncInfo, cmd hotstuff.Command) (proposal hotstuff.ProposeMsg, ok bool) { - parent, ok := f.mods.BlockChain().Get(f.mods.Synchronizer().LeafBlock().Parent()) + parent, ok := f.blockChain.Get(f.synchronizer.LeafBlock().Parent()) if !ok { return proposal, false } - grandparent, ok := f.mods.BlockChain().Get(parent.Hash()) + grandparent, ok := f.blockChain.Get(parent.Hash()) if !ok { return proposal, false } proposal = hotstuff.ProposeMsg{ - ID: f.mods.ID(), + ID: f.opts.ID(), Block: hotstuff.NewBlock( grandparent.Hash(), grandparent.QuorumCert(), cmd, - f.mods.Synchronizer().View(), - f.mods.ID(), + f.synchronizer.View(), + f.opts.ID(), ), } - if aggQC, ok := cert.AggQC(); f.mods.Options().ShouldUseAggQC() && ok { + if aggQC, ok := cert.AggQC(); f.opts.ShouldUseAggQC() && ok { proposal.AggregateQC = &aggQC } return proposal, true diff --git a/consensus/chainedhotstuff/chainedhotstuff.go b/consensus/chainedhotstuff/chainedhotstuff.go index f9a5d639..3c186b66 100644 --- a/consensus/chainedhotstuff/chainedhotstuff.go +++ b/consensus/chainedhotstuff/chainedhotstuff.go @@ -4,6 +4,7 @@ package chainedhotstuff import ( "github.com/relab/hotstuff" "github.com/relab/hotstuff/consensus" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) @@ -13,7 +14,8 @@ func init() { // ChainedHotStuff implements the pipelined three-phase HotStuff protocol. type ChainedHotStuff struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + logger logging.Logger // protocol variables @@ -27,17 +29,16 @@ func New() consensus.Rules { } } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (hs *ChainedHotStuff) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - hs.mods = mods +// InitModule initializes the module. +func (hs *ChainedHotStuff) InitModule(mods *modules.Core) { + mods.GetAll(&hs.blockChain, &hs.logger) } func (hs *ChainedHotStuff) qcRef(qc hotstuff.QuorumCert) (*hotstuff.Block, bool) { if (hotstuff.Hash{}) == qc.BlockHash() { return nil, false } - return hs.mods.BlockChain().Get(qc.BlockHash()) + return hs.blockChain.Get(qc.BlockHash()) } // CommitRule decides whether an ancestor of the block should be committed. @@ -49,7 +50,7 @@ func (hs *ChainedHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { // Note that we do not call UpdateHighQC here. // This is done through AdvanceView, which the Consensus implementation will call. - hs.mods.Logger().Debug("PRE_COMMIT: ", block1) + hs.logger.Debug("PRE_COMMIT: ", block1) block2, ok := hs.qcRef(block1.QuorumCert()) if !ok { @@ -57,7 +58,7 @@ func (hs *ChainedHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { } if block2.View() > hs.bLock.View() { - hs.mods.Logger().Debug("COMMIT: ", block2) + hs.logger.Debug("COMMIT: ", block2) hs.bLock = block2 } @@ -67,7 +68,7 @@ func (hs *ChainedHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { } if block1.Parent() == block2.Hash() && block2.Parent() == block3.Hash() { - hs.mods.Logger().Debug("DECIDE: ", block3) + hs.logger.Debug("DECIDE: ", block3) return block3 } @@ -78,18 +79,18 @@ func (hs *ChainedHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { func (hs *ChainedHotStuff) VoteRule(proposal hotstuff.ProposeMsg) bool { block := proposal.Block - qcBlock, haveQCBlock := hs.mods.BlockChain().Get(block.QuorumCert().BlockHash()) + qcBlock, haveQCBlock := hs.blockChain.Get(block.QuorumCert().BlockHash()) safe := false if haveQCBlock && qcBlock.View() > hs.bLock.View() { safe = true } else { - hs.mods.Logger().Debug("OnPropose: liveness condition failed") + hs.logger.Debug("OnPropose: liveness condition failed") // check if this block extends bLock - if hs.mods.BlockChain().Extends(block, hs.bLock) { + if hs.blockChain.Extends(block, hs.bLock) { safe = true } else { - hs.mods.Logger().Debug("OnPropose: safety condition failed") + hs.logger.Debug("OnPropose: safety condition failed") } } diff --git a/consensus/consensus.go b/consensus/consensus.go index 1367f509..fc25d90d 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -5,7 +5,10 @@ import ( "sync" "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/synchronizer" ) // Rules is the minimum interface that a consensus implementations must implement. @@ -34,7 +37,21 @@ type ProposeRuler interface { // for implementations of the ConsensusImpl interface. type consensusBase struct { impl Rules - mods *modules.ConsensusCore + + acceptor modules.Acceptor + blockChain modules.BlockChain + commandQueue modules.CommandQueue + configuration modules.Configuration + crypto modules.Crypto + eventLoop *eventloop.EventLoop + executor modules.ExecutorExt + forkHandler modules.ForkHandlerExt + leaderRotation modules.LeaderRotation + logger logging.Logger + opts *modules.Options + synchronizer modules.Synchronizer + + handel modules.Handel lastVote hotstuff.View @@ -51,22 +68,40 @@ func New(impl Rules) modules.Consensus { } } +// InitModule initializes the module. +func (cs *consensusBase) InitModule(mods *modules.Core) { + mods.GetAll( + &cs.acceptor, + &cs.blockChain, + &cs.commandQueue, + &cs.configuration, + &cs.crypto, + &cs.eventLoop, + &cs.executor, + &cs.forkHandler, + &cs.leaderRotation, + &cs.logger, + &cs.opts, + &cs.synchronizer, + ) + + mods.TryGet(&cs.handel) + + if mod, ok := cs.impl.(modules.Module); ok { + mod.InitModule(mods) + } + + cs.eventLoop.RegisterHandler(hotstuff.ProposeMsg{}, func(event any) { + cs.OnPropose(event.(hotstuff.ProposeMsg)) + }) +} + func (cs *consensusBase) CommittedBlock() *hotstuff.Block { cs.mut.Lock() defer cs.mut.Unlock() return cs.bExec } -func (cs *consensusBase) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - cs.mods = mods - if mod, ok := cs.impl.(modules.ConsensusModule); ok { - mod.InitModule(mods, opts) - } - cs.mods.EventLoop().RegisterHandler(hotstuff.ProposeMsg{}, func(event any) { - cs.OnPropose(event.(hotstuff.ProposeMsg)) - }) -} - // StopVoting ensures that no voting happens in a view earlier than `view`. func (cs *consensusBase) StopVoting(view hotstuff.View) { if cs.lastVote < view { @@ -76,21 +111,24 @@ func (cs *consensusBase) StopVoting(view hotstuff.View) { // Propose creates a new proposal. func (cs *consensusBase) Propose(cert hotstuff.SyncInfo) { - cs.mods.Logger().Debug("Propose") + cs.logger.Debug("Propose") qc, ok := cert.QC() if ok { // tell the acceptor that the previous proposal succeeded. - if qcBlock, ok := cs.mods.BlockChain().Get(qc.BlockHash()); ok { - cs.mods.Acceptor().Proposed(qcBlock.Command()) + if qcBlock, ok := cs.blockChain.Get(qc.BlockHash()); ok { + cs.acceptor.Proposed(qcBlock.Command()) } else { - cs.mods.Logger().Errorf("Could not find block for QC: %s", qc) + cs.logger.Errorf("Could not find block for QC: %s", qc) } } - cmd, ok := cs.mods.CommandQueue().Get(cs.mods.Synchronizer().ViewContext()) + ctx, cancel := synchronizer.TimeoutContext(cs.eventLoop.Context(), cs.eventLoop) + defer cancel() + + cmd, ok := cs.commandQueue.Get(ctx) if !ok { - cs.mods.Logger().Debug("Propose: No command") + cs.logger.Debug("Propose: No command") return } @@ -98,124 +136,115 @@ func (cs *consensusBase) Propose(cert hotstuff.SyncInfo) { if proposer, ok := cs.impl.(ProposeRuler); ok { proposal, ok = proposer.ProposeRule(cert, cmd) if !ok { - cs.mods.Logger().Debug("Propose: No block") + cs.logger.Debug("Propose: No block") return } } else { proposal = hotstuff.ProposeMsg{ - ID: cs.mods.ID(), + ID: cs.opts.ID(), Block: hotstuff.NewBlock( - cs.mods.Synchronizer().LeafBlock().Hash(), + cs.synchronizer.LeafBlock().Hash(), qc, cmd, - cs.mods.Synchronizer().View(), - cs.mods.ID(), + cs.synchronizer.View(), + cs.opts.ID(), ), } - if aggQC, ok := cert.AggQC(); ok && cs.mods.Options().ShouldUseAggQC() { + if aggQC, ok := cert.AggQC(); ok && cs.opts.ShouldUseAggQC() { proposal.AggregateQC = &aggQC } } - cs.mods.BlockChain().Store(proposal.Block) + cs.blockChain.Store(proposal.Block) - cs.mods.Configuration().Propose(proposal) + cs.configuration.Propose(proposal) // self vote cs.OnPropose(proposal) } func (cs *consensusBase) OnPropose(proposal hotstuff.ProposeMsg) { //nolint:gocyclo // TODO: extract parts of this method into helper functions maybe? - cs.mods.Logger().Debugf("OnPropose: %v", proposal.Block) + cs.logger.Debugf("OnPropose: %v", proposal.Block) block := proposal.Block - if cs.mods.Options().ShouldUseAggQC() && proposal.AggregateQC != nil { - highQC, ok := cs.mods.Crypto().VerifyAggregateQC(*proposal.AggregateQC) + if cs.opts.ShouldUseAggQC() && proposal.AggregateQC != nil { + highQC, ok := cs.crypto.VerifyAggregateQC(*proposal.AggregateQC) if !ok { - cs.mods.Logger().Warn("OnPropose: failed to verify aggregate QC") + cs.logger.Warn("OnPropose: failed to verify aggregate QC") return } // NOTE: for simplicity, we require that the highQC found in the AggregateQC equals the QC embedded in the block. if !block.QuorumCert().Equals(highQC) { - cs.mods.Logger().Warn("OnPropose: block QC does not equal highQC") + cs.logger.Warn("OnPropose: block QC does not equal highQC") return } } - if !cs.mods.Crypto().VerifyQuorumCert(block.QuorumCert()) { - cs.mods.Logger().Info("OnPropose: invalid QC") + if !cs.crypto.VerifyQuorumCert(block.QuorumCert()) { + cs.logger.Info("OnPropose: invalid QC") return } // ensure the block came from the leader. - if proposal.ID != cs.mods.LeaderRotation().GetLeader(block.View()) { - cs.mods.Logger().Info("OnPropose: block was not proposed by the expected leader") + if proposal.ID != cs.leaderRotation.GetLeader(block.View()) { + cs.logger.Info("OnPropose: block was not proposed by the expected leader") return } if !cs.impl.VoteRule(proposal) { - cs.mods.Logger().Info("OnPropose: Block not voted for") + cs.logger.Info("OnPropose: Block not voted for") return } - if qcBlock, ok := cs.mods.BlockChain().Get(block.QuorumCert().BlockHash()); ok { - cs.mods.Acceptor().Proposed(qcBlock.Command()) + if qcBlock, ok := cs.blockChain.Get(block.QuorumCert().BlockHash()); ok { + cs.acceptor.Proposed(qcBlock.Command()) } else { - cs.mods.Logger().Info("OnPropose: Failed to fetch qcBlock") + cs.logger.Info("OnPropose: Failed to fetch qcBlock") } - if !cs.mods.Acceptor().Accept(block.Command()) { - cs.mods.Logger().Info("OnPropose: command not accepted") + if !cs.acceptor.Accept(block.Command()) { + cs.logger.Info("OnPropose: command not accepted") return } // block is safe and was accepted - cs.mods.BlockChain().Store(block) + cs.blockChain.Store(block) - didAdvanceView := false - // we defer the following in order to speed up voting - defer func() { - if b := cs.impl.CommitRule(block); b != nil { - cs.commit(b) - } - if !didAdvanceView { - cs.mods.Synchronizer().AdvanceView(hotstuff.NewSyncInfo().WithQC(block.QuorumCert())) - } - }() + if b := cs.impl.CommitRule(block); b != nil { + cs.commit(b) + } + cs.synchronizer.AdvanceView(hotstuff.NewSyncInfo().WithQC(block.QuorumCert())) if block.View() <= cs.lastVote { - cs.mods.Logger().Info("OnPropose: block view too old") + cs.logger.Info("OnPropose: block view too old") return } - pc, err := cs.mods.Crypto().CreatePartialCert(block) + pc, err := cs.crypto.CreatePartialCert(block) if err != nil { - cs.mods.Logger().Error("OnPropose: failed to sign block: ", err) + cs.logger.Error("OnPropose: failed to sign block: ", err) return } cs.lastVote = block.View() - if cs.mods.Options().ShouldUseHandel() { - // Need to call advanceview such that the view context will be fresh. - // TODO: we could instead - cs.mods.Synchronizer().AdvanceView(hotstuff.NewSyncInfo().WithQC(block.QuorumCert())) - didAdvanceView = true - cs.mods.Handel().Begin(pc) + if cs.handel != nil { + // let Handel handle the voting + cs.handel.Begin(pc) return } - leaderID := cs.mods.LeaderRotation().GetLeader(cs.lastVote + 1) - if leaderID == cs.mods.ID() { - cs.mods.EventLoop().AddEvent(hotstuff.VoteMsg{ID: cs.mods.ID(), PartialCert: pc}) + leaderID := cs.leaderRotation.GetLeader(cs.lastVote + 1) + if leaderID == cs.opts.ID() { + cs.eventLoop.AddEvent(hotstuff.VoteMsg{ID: cs.opts.ID(), PartialCert: pc}) return } - leader, ok := cs.mods.Configuration().Replica(leaderID) + leader, ok := cs.configuration.Replica(leaderID) if !ok { - cs.mods.Logger().Warnf("Replica with ID %d was not found!", leaderID) + cs.logger.Warnf("Replica with ID %d was not found!", leaderID) return } @@ -229,14 +258,14 @@ func (cs *consensusBase) commit(block *hotstuff.Block) { cs.mut.Unlock() if err != nil { - cs.mods.Logger().Warnf("failed to commit: %v", err) + cs.logger.Warnf("failed to commit: %v", err) return } // prune the blockchain and handle forked blocks - forkedBlocks := cs.mods.BlockChain().PruneToHeight(block.View()) + forkedBlocks := cs.blockChain.PruneToHeight(block.View()) for _, block := range forkedBlocks { - cs.mods.ForkHandler().Fork(block) + cs.forkHandler.Fork(block) } } @@ -245,7 +274,7 @@ func (cs *consensusBase) commitInner(block *hotstuff.Block) error { if cs.bExec.View() >= block.View() { return nil } - if parent, ok := cs.mods.BlockChain().Get(block.Parent()); ok { + if parent, ok := cs.blockChain.Get(block.Parent()); ok { err := cs.commitInner(parent) if err != nil { return err @@ -253,8 +282,8 @@ func (cs *consensusBase) commitInner(block *hotstuff.Block) error { } else { return fmt.Errorf("failed to locate block: %s", block.Parent()) } - cs.mods.Logger().Debug("EXEC: ", block) - cs.mods.Executor().Exec(block) + cs.logger.Debug("EXEC: ", block) + cs.executor.Exec(block) cs.bExec = block return nil } diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index 59d8c948..8172bc15 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -6,8 +6,10 @@ import ( "github.com/golang/mock/gomock" "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/internal/mocks" "github.com/relab/hotstuff/internal/testutil" + "github.com/relab/hotstuff/modules" "github.com/relab/hotstuff/synchronizer" ) @@ -17,15 +19,22 @@ func TestVote(t *testing.T) { ctrl := gomock.NewController(t) bl := testutil.CreateBuilders(t, ctrl, n) cs := mocks.NewMockConsensus(ctrl) - bl[0].Register(synchronizer.New(testutil.FixedTimeout(1000)), cs) + bl[0].Add(synchronizer.New(testutil.FixedTimeout(1000)), cs) hl := bl.Build() hs := hl[0] + var ( + eventLoop *eventloop.EventLoop + blockChain modules.BlockChain + ) + + hs.GetAll(&eventLoop, &blockChain) + cs.EXPECT().Propose(gomock.AssignableToTypeOf(hotstuff.NewSyncInfo())) ok := false ctx, cancel := context.WithCancel(context.Background()) - hs.EventLoop().RegisterObserver(hotstuff.NewViewMsg{}, func(event any) { + eventLoop.RegisterObserver(hotstuff.NewViewMsg{}, func(event any) { ok = true cancel() }) @@ -35,17 +44,17 @@ func TestVote(t *testing.T) { hotstuff.NewQuorumCert(nil, 1, hotstuff.GetGenesis().Hash()), "test", 1, 1, ) - hs.BlockChain().Store(b.Block) + blockChain.Store(b.Block) for i, signer := range hl.Signers() { pc, err := signer.CreatePartialCert(b.Block) if err != nil { t.Fatalf("Failed to create partial certificate: %v", err) } - hs.EventLoop().AddEvent(hotstuff.VoteMsg{ID: hotstuff.ID(i + 1), PartialCert: pc}) + eventLoop.AddEvent(hotstuff.VoteMsg{ID: hotstuff.ID(i + 1), PartialCert: pc}) } - hs.Run(ctx) + eventLoop.Run(ctx) if !ok { t.Error("No new view event happened") diff --git a/consensus/fasthotstuff/fasthotstuff.go b/consensus/fasthotstuff/fasthotstuff.go index 9c5dc416..484b1d55 100644 --- a/consensus/fasthotstuff/fasthotstuff.go +++ b/consensus/fasthotstuff/fasthotstuff.go @@ -4,6 +4,7 @@ package fasthotstuff import ( "github.com/relab/hotstuff" "github.com/relab/hotstuff/consensus" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) @@ -13,7 +14,9 @@ func init() { // FastHotStuff is an implementation of the Fast-HotStuff protocol. type FastHotStuff struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + logger logging.Logger + synchronizer modules.Synchronizer } // New returns a new FastHotStuff instance. @@ -21,10 +24,12 @@ func New() consensus.Rules { return &FastHotStuff{} } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (fhs *FastHotStuff) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - fhs.mods = mods +// InitModule initializes the module. +func (fhs *FastHotStuff) InitModule(mods *modules.Core) { + var opts *modules.Options + + mods.GetAll(&opts, &fhs.blockChain, &fhs.logger, &fhs.synchronizer) + opts.SetShouldUseAggQC() } @@ -32,7 +37,7 @@ func (fhs *FastHotStuff) qcRef(qc hotstuff.QuorumCert) (*hotstuff.Block, bool) { if (hotstuff.Hash{}) == qc.BlockHash() { return nil, false } - return fhs.mods.BlockChain().Get(qc.BlockHash()) + return fhs.blockChain.Get(qc.BlockHash()) } // CommitRule decides whether an ancestor of the block can be committed. @@ -41,14 +46,14 @@ func (fhs *FastHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { if !ok { return nil } - fhs.mods.Logger().Debug("PRECOMMIT: ", parent) + fhs.logger.Debug("PRECOMMIT: ", parent) grandparent, ok := fhs.qcRef(parent.QuorumCert()) if !ok { return nil } if block.Parent() == parent.Hash() && block.View() == parent.View()+1 && parent.Parent() == grandparent.Hash() && parent.View() == grandparent.View()+1 { - fhs.mods.Logger().Debug("COMMIT: ", grandparent) + fhs.logger.Debug("COMMIT: ", grandparent) return grandparent } return nil @@ -59,10 +64,10 @@ func (fhs *FastHotStuff) VoteRule(proposal hotstuff.ProposeMsg) bool { // The base implementation verifies both regular QCs and AggregateQCs, and asserts that the QC embedded in the // block is the same as the highQC found in the aggregateQC. if proposal.AggregateQC != nil { - hqcBlock, ok := fhs.mods.BlockChain().Get(proposal.Block.QuorumCert().BlockHash()) - return ok && fhs.mods.BlockChain().Extends(proposal.Block, hqcBlock) + hqcBlock, ok := fhs.blockChain.Get(proposal.Block.QuorumCert().BlockHash()) + return ok && fhs.blockChain.Extends(proposal.Block, hqcBlock) } - return proposal.Block.View() >= fhs.mods.Synchronizer().View() && + return proposal.Block.View() >= fhs.synchronizer.View() && proposal.Block.View() == proposal.Block.QuorumCert().View()+1 } diff --git a/consensus/simplehotstuff/simplehotstuff.go b/consensus/simplehotstuff/simplehotstuff.go index b451481c..3ef7a6a1 100644 --- a/consensus/simplehotstuff/simplehotstuff.go +++ b/consensus/simplehotstuff/simplehotstuff.go @@ -4,6 +4,7 @@ package simplehotstuff import ( "github.com/relab/hotstuff" "github.com/relab/hotstuff/consensus" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) @@ -16,7 +17,9 @@ func init() { // Based on the simplified algorithm described in the paper // "Formal Verification of HotStuff" by Leander Jehl. type SimpleHotStuff struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + logger logging.Logger + synchronizer modules.Synchronizer locked *hotstuff.Block } @@ -28,10 +31,9 @@ func New() consensus.Rules { } } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (hs *SimpleHotStuff) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - hs.mods = mods +// InitModule initializes the module. +func (hs *SimpleHotStuff) InitModule(mods *modules.Core) { + mods.GetAll(&hs.blockChain, &hs.logger, &hs.synchronizer) } // VoteRule decides if the replica should vote for the given block. @@ -39,20 +41,20 @@ func (hs *SimpleHotStuff) VoteRule(proposal hotstuff.ProposeMsg) bool { block := proposal.Block // Rule 1: can only vote in increasing rounds - if block.View() < hs.mods.Synchronizer().View() { - hs.mods.Logger().Info("VoteRule: block view too low") + if block.View() < hs.synchronizer.View() { + hs.logger.Info("VoteRule: block view too low") return false } - parent, ok := hs.mods.BlockChain().Get(block.QuorumCert().BlockHash()) + parent, ok := hs.blockChain.Get(block.QuorumCert().BlockHash()) if !ok { - hs.mods.Logger().Info("VoteRule: missing parent block: ", block.QuorumCert().BlockHash()) + hs.logger.Info("VoteRule: missing parent block: ", block.QuorumCert().BlockHash()) return false } // Rule 2: can only vote if parent's view is greater than or equal to locked block's view. if parent.View() < hs.locked.View() { - hs.mods.Logger().Info("OnPropose: parent too old") + hs.logger.Info("OnPropose: parent too old") return false } @@ -62,20 +64,20 @@ func (hs *SimpleHotStuff) VoteRule(proposal hotstuff.ProposeMsg) bool { // CommitRule decides if an ancestor of the block can be committed, and returns the ancestor, otherwise returns nil. func (hs *SimpleHotStuff) CommitRule(block *hotstuff.Block) *hotstuff.Block { // will consider if the great-grandparent of the new block can be committed. - p, ok := hs.mods.BlockChain().Get(block.QuorumCert().BlockHash()) + p, ok := hs.blockChain.Get(block.QuorumCert().BlockHash()) if !ok { return nil } - gp, ok := hs.mods.BlockChain().Get(p.QuorumCert().BlockHash()) + gp, ok := hs.blockChain.Get(p.QuorumCert().BlockHash()) if ok && gp.View() > hs.locked.View() { hs.locked = gp - hs.mods.Logger().Debug("Locked: ", gp) + hs.logger.Debug("Locked: ", gp) } else if !ok { return nil } - ggp, ok := hs.mods.BlockChain().Get(gp.QuorumCert().BlockHash()) + ggp, ok := hs.blockChain.Get(gp.QuorumCert().BlockHash()) // we commit the great-grandparent of the block if its grandchild is certified, // which we already know is true because the new block contains the grandchild's certificate, // and if the great-grandparent's view + 2 equals the grandchild's view. diff --git a/consensus/votingmachine.go b/consensus/votingmachine.go index 6721a811..320113d8 100644 --- a/consensus/votingmachine.go +++ b/consensus/votingmachine.go @@ -4,13 +4,22 @@ import ( "sync" "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) // VotingMachine collects votes. type VotingMachine struct { + blockChain modules.BlockChain + configuration modules.Configuration + crypto modules.Crypto + eventLoop *eventloop.EventLoop + logger logging.Logger + synchronizer modules.Synchronizer + opts *modules.Options + mut sync.Mutex - mods *modules.ConsensusCore verifiedVotes map[hotstuff.Hash][]hotstuff.PartialCert // verified votes that could become a QC } @@ -21,17 +30,25 @@ func NewVotingMachine() *VotingMachine { } } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (vm *VotingMachine) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - vm.mods = mods - vm.mods.EventLoop().RegisterHandler(hotstuff.VoteMsg{}, func(event any) { vm.OnVote(event.(hotstuff.VoteMsg)) }) +// InitModule initializes the VotingMachine. +func (vm *VotingMachine) InitModule(mods *modules.Core) { + mods.GetAll( + &vm.blockChain, + &vm.configuration, + &vm.crypto, + &vm.eventLoop, + &vm.logger, + &vm.synchronizer, + &vm.opts, + ) + + vm.eventLoop.RegisterHandler(hotstuff.VoteMsg{}, func(event any) { vm.OnVote(event.(hotstuff.VoteMsg)) }) } // OnVote handles an incoming vote. func (vm *VotingMachine) OnVote(vote hotstuff.VoteMsg) { cert := vote.PartialCert - vm.mods.Logger().Debugf("OnVote(%d): %.8s", vote.ID, cert.BlockHash()) + vm.logger.Debugf("OnVote(%d): %.8s", vote.ID, cert.BlockHash()) var ( block *hotstuff.Block @@ -40,30 +57,30 @@ func (vm *VotingMachine) OnVote(vote hotstuff.VoteMsg) { if !vote.Deferred { // first, try to get the block from the local cache - block, ok = vm.mods.BlockChain().LocalGet(cert.BlockHash()) + block, ok = vm.blockChain.LocalGet(cert.BlockHash()) if !ok { // if that does not work, we will try to handle this event later. // hopefully, the block has arrived by then. - vm.mods.Logger().Debugf("Local cache miss for block: %.8s", cert.BlockHash()) + vm.logger.Debugf("Local cache miss for block: %.8s", cert.BlockHash()) vote.Deferred = true - vm.mods.EventLoop().DelayUntil(hotstuff.ProposeMsg{}, vote) + vm.eventLoop.DelayUntil(hotstuff.ProposeMsg{}, vote) return } } else { // if the block has not arrived at this point we will try to fetch it. - block, ok = vm.mods.BlockChain().Get(cert.BlockHash()) + block, ok = vm.blockChain.Get(cert.BlockHash()) if !ok { - vm.mods.Logger().Debugf("Could not find block for vote: %.8s.", cert.BlockHash()) + vm.logger.Debugf("Could not find block for vote: %.8s.", cert.BlockHash()) return } } - if block.View() <= vm.mods.Synchronizer().LeafBlock().View() { + if block.View() <= vm.synchronizer.LeafBlock().View() { // too old return } - if vm.mods.Options().ShouldVerifyVotesSync() { + if vm.opts.ShouldVerifyVotesSync() { vm.verifyCert(cert, block) } else { go vm.verifyCert(cert, block) @@ -71,8 +88,8 @@ func (vm *VotingMachine) OnVote(vote hotstuff.VoteMsg) { } func (vm *VotingMachine) verifyCert(cert hotstuff.PartialCert, block *hotstuff.Block) { - if !vm.mods.Crypto().VerifyPartialCert(cert) { - vm.mods.Logger().Info("OnVote: Vote could not be verified!") + if !vm.crypto.VerifyPartialCert(cert) { + vm.logger.Info("OnVote: Vote could not be verified!") return } @@ -83,8 +100,8 @@ func (vm *VotingMachine) verifyCert(cert hotstuff.PartialCert, block *hotstuff.B defer func() { // delete any pending QCs with lower height than bLeaf for k := range vm.verifiedVotes { - if block, ok := vm.mods.BlockChain().LocalGet(k); ok { - if block.View() <= vm.mods.Synchronizer().LeafBlock().View() { + if block, ok := vm.blockChain.LocalGet(k); ok { + if block.View() <= vm.synchronizer.LeafBlock().View() { delete(vm.verifiedVotes, k) } } else { @@ -97,16 +114,16 @@ func (vm *VotingMachine) verifyCert(cert hotstuff.PartialCert, block *hotstuff.B votes = append(votes, cert) vm.verifiedVotes[cert.BlockHash()] = votes - if len(votes) < vm.mods.Configuration().QuorumSize() { + if len(votes) < vm.configuration.QuorumSize() { return } - qc, err := vm.mods.Crypto().CreateQuorumCert(block, votes) + qc, err := vm.crypto.CreateQuorumCert(block, votes) if err != nil { - vm.mods.Logger().Info("OnVote: could not create QC for block: ", err) + vm.logger.Info("OnVote: could not create QC for block: ", err) return } delete(vm.verifiedVotes, cert.BlockHash()) - vm.mods.EventLoop().AddEvent(hotstuff.NewViewMsg{ID: vm.mods.ID(), SyncInfo: hotstuff.NewSyncInfo().WithQC(qc)}) + vm.eventLoop.AddEvent(hotstuff.NewViewMsg{ID: vm.opts.ID(), SyncInfo: hotstuff.NewSyncInfo().WithQC(qc)}) } diff --git a/crypto/bls12/bls12.go b/crypto/bls12/bls12.go index 3cbd903b..98bbc8c1 100644 --- a/crypto/bls12/bls12.go +++ b/crypto/bls12/bls12.go @@ -11,6 +11,7 @@ import ( bls12 "github.com/kilic/bls12-381" "github.com/relab/hotstuff" "github.com/relab/hotstuff/crypto" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) @@ -139,7 +140,9 @@ func firstParticipant(participants hotstuff.IDSet) hotstuff.ID { } type bls12Base struct { - mods *modules.ConsensusCore + configuration modules.Configuration + logger logging.Logger + opts *modules.Options mut sync.RWMutex // popCache caches the proof-of-possession results of popVerify for each public key. @@ -153,31 +156,34 @@ func New() modules.CryptoBase { } } -// InitModule gives the module a reference to the ConsensusCore object. +// InitModule gives the module a reference to the Core object. // It also allows the module to set module options using the OptionsBuilder. -func (bls *bls12Base) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - bls.mods = mods +func (bls *bls12Base) InitModule(mods *modules.Core) { + mods.GetAll( + &bls.configuration, + &bls.logger, + &bls.opts, + ) pop := bls.popProve() b := bls12.NewG2().ToCompressed(pop) - opts.SetConnectionMetadata(popMetadataKey, string(b)) + bls.opts.SetConnectionMetadata(popMetadataKey, string(b)) } func (bls *bls12Base) privateKey() *PrivateKey { - pk := bls.mods.PrivateKey() - return pk.(*PrivateKey) + return bls.opts.PrivateKey().(*PrivateKey) } func (bls *bls12Base) publicKey(id hotstuff.ID) (pubKey *PublicKey, ok bool) { - if replica, ok := bls.mods.Configuration().Replica(id); ok { - if replica.ID() != bls.mods.ID() && !bls.checkPop(replica) { - bls.mods.Logger().Warnf("Invalid POP for replica %d", id) + if replica, ok := bls.configuration.Replica(id); ok { + if replica.ID() != bls.opts.ID() && !bls.checkPop(replica) { + bls.logger.Warnf("Invalid POP for replica %d", id) return nil, false } if pubKey, ok = replica.PublicKey().(*PublicKey); ok { return pubKey, true } - bls.mods.Logger().Errorf("Unsupported public key type: %T", replica.PublicKey()) + bls.logger.Errorf("Unsupported public key type: %T", replica.PublicKey()) } return nil, false } @@ -220,7 +226,7 @@ func (bls *bls12Base) popProve() *bls12.PointG2 { pubKey := bls.privateKey().Public().(*PublicKey) proof, err := bls.coreSign(pubKey.ToBytes(), domainPOP) if err != nil { - bls.mods.Logger().Panicf("Failed to generate proof-of-possession: %v", err) + bls.logger.Panicf("Failed to generate proof-of-possession: %v", err) } return proof } @@ -232,13 +238,13 @@ func (bls *bls12Base) popVerify(pubKey *PublicKey, proof *bls12.PointG2) bool { func (bls *bls12Base) checkPop(replica modules.Replica) (valid bool) { defer func() { if !valid { - bls.mods.Logger().Warnf("Invalid proof-of-possession for replica %d", replica.ID()) + bls.logger.Warnf("Invalid proof-of-possession for replica %d", replica.ID()) } }() popBytes, ok := replica.Metadata()[popMetadataKey] if !ok { - bls.mods.Logger().Warnf("Missing proof-of-possession for replica: %d", replica.ID()) + bls.logger.Warnf("Missing proof-of-possession for replica: %d", replica.ID()) return false } @@ -321,7 +327,7 @@ func (bls *bls12Base) Sign(message []byte) (signature hotstuff.QuorumSignature, return nil, fmt.Errorf("bls12: coreSign failed: %w", err) } bf := crypto.Bitfield{} - bf.Add(bls.mods.ID()) + bf.Add(bls.opts.ID()) return &AggregateSignature{sig: *p, participants: bf}, nil } @@ -349,7 +355,7 @@ func (bls *bls12Base) Combine(signatures ...hotstuff.QuorumSignature) (combined } g2.Add(&agg, &agg, &sig2.sig) } else { - bls.mods.Logger().Panicf("cannot combine incompatible signature type %T (expected %T)", sig1, sig2) + bls.logger.Panicf("cannot combine incompatible signature type %T (expected %T)", sig1, sig2) } } return &AggregateSignature{sig: agg, participants: participants}, nil @@ -359,7 +365,7 @@ func (bls *bls12Base) Combine(signatures ...hotstuff.QuorumSignature) (combined func (bls *bls12Base) Verify(signature hotstuff.QuorumSignature, message []byte) bool { s, ok := signature.(*AggregateSignature) if !ok { - bls.mods.Logger().Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) + bls.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } n := s.Participants().Len() @@ -368,7 +374,7 @@ func (bls *bls12Base) Verify(signature hotstuff.QuorumSignature, message []byte) id := firstParticipant(s.Participants()) pk, ok := bls.publicKey(id) if !ok { - bls.mods.Logger().Warnf("Missing public key for ID %d", id) + bls.logger.Warnf("Missing public key for ID %d", id) return false } return bls.coreVerify(pk, message, &s.sig, domain) @@ -382,7 +388,7 @@ func (bls *bls12Base) Verify(signature hotstuff.QuorumSignature, message []byte) pks = append(pks, pk) return true } - bls.mods.Logger().Warnf("Missing public key for ID %d", id) + bls.logger.Warnf("Missing public key for ID %d", id) return false }) if len(pks) != n { @@ -395,7 +401,7 @@ func (bls *bls12Base) Verify(signature hotstuff.QuorumSignature, message []byte) func (bls *bls12Base) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool { s, ok := signature.(*AggregateSignature) if !ok { - bls.mods.Logger().Panicf("cannot verify incompatible signature type %T (expected %T)", signature, s) + bls.logger.Panicf("cannot verify incompatible signature type %T (expected %T)", signature, s) } if s.Participants().Len() != len(batch) { @@ -409,7 +415,7 @@ func (bls *bls12Base) BatchVerify(signature hotstuff.QuorumSignature, batch map[ msgs = append(msgs, msg) pk, ok := bls.publicKey(id) if !ok { - bls.mods.Logger().Warnf("Missing public key for ID %d", id) + bls.logger.Warnf("Missing public key for ID %d", id) return false } pks = append(pks, pk) diff --git a/crypto/cache.go b/crypto/cache.go index 90f009e2..ae545ade 100644 --- a/crypto/cache.go +++ b/crypto/cache.go @@ -30,11 +30,11 @@ func NewCache(impl modules.CryptoBase, capacity int) modules.Crypto { }) } -// InitModule gives the module a reference to the ConsensusCore object. +// InitModule gives the module a reference to the Core object. // It also allows the module to set module options using the OptionsBuilder. -func (cache *cache) InitModule(mods *modules.ConsensusCore, cfg *modules.OptionsBuilder) { - if mod, ok := cache.impl.(modules.ConsensusModule); ok { - mod.InitModule(mods, cfg) +func (cache *cache) InitModule(mods *modules.Core) { + if mod, ok := cache.impl.(modules.Module); ok { + mod.InitModule(mods) } } diff --git a/crypto/crypto.go b/crypto/crypto.go index 31b78ebb..e0fa64c7 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -2,12 +2,17 @@ package crypto import ( + "bytes" + "github.com/relab/hotstuff" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/util/gpool" ) type crypto struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + configuration modules.Configuration + modules.CryptoBase } @@ -17,18 +22,35 @@ func New(impl modules.CryptoBase) modules.Crypto { return &crypto{CryptoBase: impl} } -// InitModule gives the module a reference to the ConsensusCore object. +// InitModule gives the module a reference to the Core object. // It also allows the module to set module options using the OptionsBuilder. -func (c *crypto) InitModule(mods *modules.ConsensusCore, cfg *modules.OptionsBuilder) { - c.mods = mods - if mod, ok := c.CryptoBase.(modules.ConsensusModule); ok { - mod.InitModule(mods, cfg) +func (c *crypto) InitModule(mods *modules.Core) { + mods.GetAll( + &c.blockChain, + &c.configuration, + ) + + if mod, ok := c.CryptoBase.(modules.Module); ok { + mod.InitModule(mods) } } +var bufferPool gpool.Pool[bytes.Buffer] + // CreatePartialCert signs a single block and returns the partial certificate. func (c crypto) CreatePartialCert(block *hotstuff.Block) (cert hotstuff.PartialCert, err error) { - sig, err := c.Sign(block.ToBytes()) + buf := bufferPool.Get() + _, err = block.WriteTo(&buf) + if err != nil { + return cert, err + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + sig, err := c.Sign(buf.Bytes()) if err != nil { return hotstuff.PartialCert{}, err } @@ -90,11 +112,23 @@ func (c crypto) CreateAggregateQC(view hotstuff.View, timeouts []hotstuff.Timeou // VerifyPartialCert verifies a single partial certificate. func (c crypto) VerifyPartialCert(cert hotstuff.PartialCert) bool { - block, ok := c.mods.BlockChain().Get(cert.BlockHash()) + block, ok := c.blockChain.Get(cert.BlockHash()) if !ok { return false } - return c.Verify(cert.Signature(), block.ToBytes()) + + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return false + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + return c.Verify(cert.Signature(), buf.Bytes()) } // VerifyQuorumCert verifies a quorum certificate. @@ -103,14 +137,26 @@ func (c crypto) VerifyQuorumCert(qc hotstuff.QuorumCert) bool { if qc.BlockHash() == hotstuff.GetGenesis().Hash() { return true } - if qc.Signature().Participants().Len() < c.mods.Configuration().QuorumSize() { + if qc.Signature().Participants().Len() < c.configuration.QuorumSize() { return false } - block, ok := c.mods.BlockChain().Get(qc.BlockHash()) + block, ok := c.blockChain.Get(qc.BlockHash()) if !ok { return false } - return c.Verify(qc.Signature(), block.ToBytes()) + + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return false + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + return c.Verify(qc.Signature(), buf.Bytes()) } // VerifyTimeoutCert verifies a timeout certificate. @@ -119,7 +165,7 @@ func (c crypto) VerifyTimeoutCert(tc hotstuff.TimeoutCert) bool { if tc.View() == 0 { return true } - if tc.Signature().Participants().Len() < c.mods.Configuration().QuorumSize() { + if tc.Signature().Participants().Len() < c.configuration.QuorumSize() { return false } return c.Verify(tc.Signature(), tc.View().ToBytes()) @@ -139,7 +185,7 @@ func (c crypto) VerifyAggregateQC(aggQC hotstuff.AggregateQC) (highQC hotstuff.Q SyncInfo: hotstuff.NewSyncInfo().WithQC(qc), }.ToBytes() } - if aggQC.Sig().Participants().Len() < c.mods.Configuration().QuorumSize() { + if aggQC.Sig().Participants().Len() < c.configuration.QuorumSize() { return hotstuff.QuorumCert{}, false } // both the batched aggQC signatures and the highQC must be verified diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index cf355e98..56814a4f 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -1,9 +1,10 @@ package crypto_test import ( - "github.com/relab/hotstuff/modules" "testing" + "github.com/relab/hotstuff/modules" + "github.com/golang/mock/gomock" "github.com/relab/hotstuff" "github.com/relab/hotstuff/crypto" @@ -215,13 +216,20 @@ func newTestData(t *testing.T, ctrl *gomock.Controller, n int, newFunc func() mo bl := testutil.CreateBuilders(t, ctrl, n, testutil.GenerateKeys(t, n, keyFunc)...) for _, builder := range bl { signer := newFunc() - builder.Register(signer) + builder.Add(signer) } hl := bl.Build() - block := createBlock(t, hl[0].Crypto()) + + var signer modules.Crypto + hl[0].Get(&signer) + + block := createBlock(t, signer) for _, mods := range hl { - mods.BlockChain().Store(block) + var blockChain modules.BlockChain + mods.Get(&blockChain) + + blockChain.Store(block) } return testData{ diff --git a/crypto/ecdsa/ecdsa.go b/crypto/ecdsa/ecdsa.go index c61495be..961e0b9b 100644 --- a/crypto/ecdsa/ecdsa.go +++ b/crypto/ecdsa/ecdsa.go @@ -6,11 +6,14 @@ import ( "crypto/rand" "crypto/sha256" "fmt" + "io" "math/big" "github.com/relab/hotstuff" "github.com/relab/hotstuff/crypto" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/util" "golang.org/x/exp/slices" ) @@ -60,6 +63,11 @@ func (sig Signature) ToBytes() []byte { return b } +// WriteTo writes the signature to the writer. +func (sig Signature) WriteTo(writer io.Writer) (n int64, err error) { + return util.WriteAllTo(writer, sig.r, sig.s) +} + // MultiSignature is a set of (partial) signatures. type MultiSignature map[hotstuff.ID]*Signature @@ -87,6 +95,25 @@ func (sig MultiSignature) ToBytes() []byte { return b } +// WriteTo writes the multi signature to the writer. +func (sig MultiSignature) WriteTo(writer io.Writer) (n int64, err error) { + // sort by ID to make it deterministic + order := make([]hotstuff.ID, 0, len(sig)) + for _, signature := range sig { + order = append(order, signature.signer) + } + slices.Sort(order) + var nn int64 + for _, id := range order { + nn, err = sig[id].WriteTo(writer) + n += nn + if err != nil { + return n, err + } + } + return n, nil +} + // Participants returns the IDs of replicas who participated in the threshold signature. func (sig MultiSignature) Participants() hotstuff.IDSet { return sig @@ -132,7 +159,9 @@ var _ hotstuff.QuorumSignature = (*MultiSignature)(nil) var _ hotstuff.IDSet = (*MultiSignature)(nil) type ecdsaBase struct { - mods *modules.ConsensusCore + configuration modules.Configuration + logger logging.Logger + opts *modules.Options } // New returns a new instance of the ECDSA CryptoBase implementation. @@ -141,14 +170,17 @@ func New() modules.CryptoBase { } func (ec *ecdsaBase) getPrivateKey() *ecdsa.PrivateKey { - pk := ec.mods.PrivateKey() - return pk.(*ecdsa.PrivateKey) + return ec.opts.PrivateKey().(*ecdsa.PrivateKey) } -// InitModule gives the module a reference to the ConsensusCore object. +// InitModule gives the module a reference to the Core object. // It also allows the module to set module options using the OptionsBuilder. -func (ec *ecdsaBase) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - ec.mods = mods +func (ec *ecdsaBase) InitModule(mods *modules.Core) { + mods.GetAll( + &ec.configuration, + &ec.logger, + &ec.opts, + ) } // Sign creates a cryptographic signature of the given message. @@ -158,10 +190,10 @@ func (ec *ecdsaBase) Sign(message []byte) (signature hotstuff.QuorumSignature, e if err != nil { return nil, fmt.Errorf("ecdsa: sign failed: %w", err) } - return MultiSignature{ec.mods.ID(): &Signature{ + return MultiSignature{ec.opts.ID(): &Signature{ r: r, s: s, - signer: ec.mods.ID(), + signer: ec.opts.ID(), }}, nil } @@ -182,7 +214,7 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q ts[id] = s } } else { - ec.mods.Logger().Panicf("cannot combine signature of incompatible type %T (expected %T)", sig1, sig2) + ec.logger.Panicf("cannot combine signature of incompatible type %T (expected %T)", sig1, sig2) } } @@ -193,7 +225,7 @@ func (ec *ecdsaBase) Combine(signatures ...hotstuff.QuorumSignature) (hotstuff.Q func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) bool { s, ok := signature.(MultiSignature) if !ok { - ec.mods.Logger().Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) + ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } n := signature.Participants().Len() @@ -224,7 +256,7 @@ func (ec *ecdsaBase) Verify(signature hotstuff.QuorumSignature, message []byte) func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[hotstuff.ID][]byte) bool { s, ok := signature.(MultiSignature) if !ok { - ec.mods.Logger().Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) + ec.logger.Panicf("cannot verify signature of incompatible type %T (expected %T)", signature, s) } n := signature.Participants().Len() @@ -258,9 +290,9 @@ func (ec *ecdsaBase) BatchVerify(signature hotstuff.QuorumSignature, batch map[h } func (ec *ecdsaBase) verifySingle(sig *Signature, hash hotstuff.Hash) bool { - replica, ok := ec.mods.Configuration().Replica(sig.Signer()) + replica, ok := ec.configuration.Replica(sig.Signer()) if !ok { - ec.mods.Logger().Warnf("ecdsaBase: got signature from replica whose ID (%d) was not in the config.", sig.Signer()) + ec.logger.Warnf("ecdsaBase: got signature from replica whose ID (%d) was not in the config.", sig.Signer()) return false } pk := replica.PublicKey().(*ecdsa.PublicKey) diff --git a/docs/modules.md b/docs/modules.md new file mode 100644 index 00000000..bd1f4c89 --- /dev/null +++ b/docs/modules.md @@ -0,0 +1,286 @@ +# Challenges of a Modular BFT Consensus Implementation + +## Contents + +- [Challenges of a Modular BFT Consensus Implementation](#challenges-of-a-modular-bft-consensus-implementation) + - [Contents](#contents) + - [Introduction](#introduction) + - [The Circular Dependency Problem](#the-circular-dependency-problem) + - [The Best Practices Solution](#the-best-practices-solution) + - [Event-based Indirection](#event-based-indirection) + - [Deferred Initialization](#deferred-initialization) + - [The Composition Problem](#the-composition-problem) + - [The Module System](#the-module-system) + - [A Module Registry](#a-module-registry) + - [Conclusion](#conclusion) + +## Introduction + +This project, called `hotstuff`, started out as an implementation of the HotStuff BFT consensus protocol. +Over time, however, the project has evolved to include implementations of other variants of the HotStuff protocol, +as well as other protocols solving different problems, such as the view synchronization problem (see the view synchronizer), +and scaling problems (see Handel). + +Writing implementations of these protocols is one thing, but the protocols should also be able to function in practice, +interoperating to achieve BFT consensus. +Each protocol may depend upon other components (we will call them modules) in order to perform its functions. +For example, the core consensus protocol (that is HotStuff and its variants) depends on a view synchronizer module (also known as a pacemaker). +Our implementation of the consensus protocol requires the following features of the synchronizer module: + +- Keeps track of the current view (based on the highest known QC, etc.). +- Eventually triggers a proposal in the leader of a view. +- Eventually raises a timeout if progress in the current view has stalled. + +These requirements inform our implementation of the view synchronizer and helps us to design an interface that the consensus protocol can use to interoperate with it. + +When implementing this system, we run into a number of programming challenges. +In this document, we will discuss these problems and assess different solutions to them. + +## The Circular Dependency Problem + +In implementing a system that is modular in the way described above where each module may have multiple different implementations and may also depend on several other such components, +we run into problems when trying to initialize the modules: +To initialize our consensus module, we need to initialize our synchronizer module. +However, our synchronizer module might also depend on our consensus module, +depending on which synchronizer implementation is used. +Here we have a chicken and egg situation. + +### The Best Practices Solution + +Due to the problem described above, as well as other problems, circular dependencies in code are considered bad practice. +Hence, following the "best practices" would call for refactoring of the code to avoid circular dependencies. +Unfortunately, this solution places challenging constraints on what we can do when creating new implementations of a module. +For example, our first implementation of a leader rotation scheme, the round-robin scheme, has no dependencies. +But later, when we want to implement on a more dynamic leader rotation scheme, +we realize that we need historical information about the previous views and thus need to interact with the synchronizer and blockchain modules. +The synchronizer probably needs to know about the identity of the leader, so we now have to refactor these somehow. + +It seems then, that strict adherence to the best practice of avoiding circular dependencies is only going to complicate the matter of creating new module implementations. +Of course, we should keep this best practice in mind, but we should allow circular dependencies in order to make it easier to implement new modules. + +### Event-based Indirection + +This is not a good solution for reasons we'll get to later, but it warrants discussion anyway. +The idea is to use the event system that we have developed to facilitate communication between modules. +Instead of each module having a direct dependency on other modules, +it would depend upon the event system to deliver some request to another module. +The other module would then process the request and emit an event containing the result. + +The problems of this approach are fairly obvious. +First, it is less efficient than calling an interface method. +Second, it introduces a whole host of implementation challenges. + +However, interactions between modules are already in the form of events. +For example, the delivery of messages from the network, or the occurrence of a timeout or proposal. +Therefore, it may make sense to use this solution for some interactions between modules, +but it is not sensible to use this as a solution to the circular dependency problem. + +### Deferred Initialization + +This idea is simple: First create the instances of all the modules. +Afterwards, initialize the modules using the now existing, albeit uninitialized, modules. +Consider the following code example: + +```go +type A struct { b *B } +func NewA() *A { return &A{} } +func (a *A) Init(*b B) { a.b = b } + +type B struct { a *A } +func NewB() *B { return &B{} } +func (b *B) Init(*a A) { b.a = a } + +func main() { + a, b := NewA(), NewB() + a.Init(b) + b.Init(a) +} +``` + +Here we are able to initialize both the `A` and `B` modules which are mutually dependent by deferring their full initialization until after they have been constructed. +This solution works well as long as the modules can exist in an "uninitialized" state. +The solution we have implemented is based on this idea, but it is also connected to our solution to the next problem. + +## The Composition Problem + +Another problem related to the initialization of our modules is this: +How do we build a composition of modules based on a requested configuration. +In other words, how do we select a certain subset of the available modules and initialize them? + +For example, let's say that we want to initialize a system with two kinds of modules. +That is, two different module interfaces. +Let us call them A and B. +Interface A is implemented by modules A1 and A2. +Interface B is implemented by module B1 only. +We will use an arrow notation to express dependency. +Let us say that A1 depends on B, and A2 and B1 are independent. +If A1 is chosen, then an implementation of B must be provided to A1. +However, if A2 is chosen, it does not need a B implementation in order to work. + +A naive implementation of this system could be done like this: + +```go +func compose(choiceA string) { + var ( + a A + b B + ) + b = NewB1() + if choiceA == "A1" { + a = NewA1(b) + } else { + a = NewA2() + } +} +``` + +This naive implementation is not very good because it is difficult to extend when adding new modules. +Also note that this does not include the solution to the circular dependency problem discussed above. +Implementing deferred initialization as discussed above would further extend the amount of boilerplate code in the naive implementation. + +### The Module System + +Our solution to the composition problem and the circular dependency problem is the *module system*. +It is essentially a combination of dependency injection and deferred initialization. +In short, it is a set of interfaces and data structures that simplifies the composition of modules. +The basic idea is this: + +Each module may implement the following `Module` interface. + +```go +type Module interface { + InitModule(mods *modules.Core) +} +``` + +The `InitModule` method is called by the module system to give the module an opportunity to initialize itself. +The module does this by calling the `Get`, `GetAll` or `TryGet` methods of the `modules.Core` object. +These methods take a pointer to the variable where a module should be stored. +The module system then looks for a module of the requested type and stores it in the pointer. + +For example: + +```go +type A1 struct{ b B } + +func (a *A1) InitModule(mods *modules.Core) { + mods.Get(&a.b) +} +``` + +The module system collects a list of modules passed to a `modules.Builder` object. +When the builder's `Build` method is called, all modules added to the builder are initialized, +if they implement the `Module` interface. + +But how does the module system know what interface a module implements? +How does the `Get` method find the correct module to store in the pointer? +Either the modules must explicitly declare what type they want to provide to other modules, +or we could simply check all registered modules to see if any of them "fit" in the pointer. +Go's `reflect` package supports this via a `Type.AssignableTo` method. +For now, this is what we have implemented, as it feels more natural in Go which already has implicit interfaces. +However, an explicit version could work like this: + + + +A `Provider` interface has a single method called `ModuleType()`: + +```go +type Provider interface { + ModuleType() any +} +``` + +This method should return a pointer to the type (typically an interface) that the module provides to other modules. +For a module implementing interface A, this method should simply return `new(A)`. +Additionally, we can make it really easy to implement this interface by using generics: + +```go +type Implements[T any] struct{} + +func (Implements[T]) ModuleType() any { + return new(T) +} +``` + +By embedding `modules.Implements[T]` where T specifies the module's interface, +the module system can become aware of which modules provide what interfaces. +Note that this does not ensure that the module in fact implements `T`, +it is more like an annotation that promises to the module system that `T` is implemented. + +Now we have most of the solution. To compare with the naive code, we now have this: + +```go +func compose(choiceA string) { + builder := modules.NewBuilder() + builder.Add(NewB1()) + if choiceA == "A1" { + builder.Add(NewA1()) + } else { + builder.Add(NewA2()) + } + mods := builder.Build() + + var ( + a A + b B + ) + mods.GetAll(&a, &b) +} +``` + +This doesn't really look a lot better than the naive code, but note the following: + +- Circular dependencies are now supported. +- We did not have to specify in the composition method that A1 depends on interface B. + +### A Module Registry + +To remove the need for an if/else or switch statement in our compose function, we add a *module registry*. +This registry maps the constructor for each module to the module's name. +Each module must register itself with the registry in order to become available. + +Modules register themselves by calling a global `modules.Register` function, providing its name and constructor. +For example: + +```go +func init() { + modules.RegisterModule("A1", NewA1) +} +``` + +Then, a module can be constructed by calling `modules.New` with the name of the requested module. + +After adding the module registry, our compose function looks like this: + +```go +func compose(choiceA string) { + builder := modules.NewBuilder() + builder.Add( + NewB1(), + modules.New(choiceA), + ) + mods := builder.Build() + + var ( + a A + b B + ) + mods.GetAll(&a, &b) +} +``` + +Now, the code is very easy to extend. +If we want to add a new module, we write our implementation, +ensure it implements the `Provider` and `Module` interfaces as necessary, and then register it in the module registry. + +## Conclusion + +This document has described two problems with implementing a modular system like `hotstuff`. +Namely, the circular dependency problem and the composition problem. +These both relate to the initialization of the system. +First, how do we initialize modules that are mutually dependent? +We have seen that requiring all circular dependencies to be removed makes it difficult to implement new modules. +We have seen that an indirection approach using the event system is too complicated. +In the end, we conclude that deferred initialization is the simplest of the three solutions. +Second, how do we compose together the specific module implementations that we want? +We have seen how the module system and module registry greatly simplify this process. diff --git a/eventloop/eventloop.go b/eventloop/eventloop.go index a4bc1762..6bbb3093 100644 --- a/eventloop/eventloop.go +++ b/eventloop/eventloop.go @@ -11,23 +11,55 @@ import ( "reflect" "sync" "time" + + "github.com/relab/hotstuff/util/gpool" ) +type handlerOpts struct { + async bool + priority bool +} + +// HandlerOption sets configuration options for event handlers. +type HandlerOption func(*handlerOpts) + +// RunAsync instructs the eventloop to run the handler asynchronously. +func RunAsync() HandlerOption { + return func(ho *handlerOpts) { + ho.async = true + } +} + +// WithPriority instructs the eventloop to prioritize running the handler before others. +// This guarantees that the handler runs before handlers that have not requested priority. +func WithPriority() HandlerOption { + return func(ho *handlerOpts) { + ho.priority = true + } +} + // EventHandler processes an event. type EventHandler func(event any) +type handler struct { + callback EventHandler + opts handlerOpts +} + // EventLoop accepts events of any type and executes relevant event handlers. // It supports registering both observers and handlers based on the type of event that they accept. // The difference between them is that there can be many observers per event type, but only one handler, // and the handler is executed last. type EventLoop struct { - mut sync.Mutex + eventQ queue + + mut sync.Mutex // protects the following: + + ctx context.Context // set by Run - eventQ queue waitingEvents map[reflect.Type][]any - handlers map[reflect.Type]EventHandler - observers map[reflect.Type][]EventHandler + handlers map[reflect.Type][]handler tickers map[int]*ticker tickerID int @@ -36,37 +68,99 @@ type EventLoop struct { // New returns a new event loop with the requested buffer size. func New(bufferSize uint) *EventLoop { el := &EventLoop{ + ctx: context.Background(), eventQ: newQueue(bufferSize), waitingEvents: make(map[reflect.Type][]any), - handlers: make(map[reflect.Type]EventHandler), - observers: make(map[reflect.Type][]EventHandler), + handlers: make(map[reflect.Type][]handler), tickers: make(map[int]*ticker), } return el } -// RegisterHandler registers a handler for events with the same type as the 'eventType' argument. -// There can be only one handler per event type, and the handler is executed after any observers. -func (el *EventLoop) RegisterHandler(eventType any, handler EventHandler) { - el.handlers[reflect.TypeOf(eventType)] = handler +// RegisterObserver registers a handler with priority. +// Deprecated: use RegisterHandler and the WithPriority option instead. +func (el *EventLoop) RegisterObserver(eventType any, handler EventHandler) int { + return el.RegisterHandler(eventType, handler, WithPriority()) +} + +// UnregisterObserver unregister a handler. +// Deprecated: use UnregisterHandler instead. +func (el *EventLoop) UnregisterObserver(eventType any, id int) { + el.UnregisterHandler(eventType, id) } -// RegisterObserver registers an observer for events with the same type as the 'eventType' argument. -// The observers are executed before the handler. -func (el *EventLoop) RegisterObserver(eventType any, observer EventHandler) { +// RegisterHandler registers an event handler. The handler will +func (el *EventLoop) RegisterHandler(eventType any, callback EventHandler, opts ...HandlerOption) int { + h := handler{callback: callback} + + for _, opt := range opts { + opt(&h.opts) + } + + el.mut.Lock() + defer el.mut.Unlock() t := reflect.TypeOf(eventType) - el.observers[t] = append(el.observers[t], observer) + + handlers := el.handlers[t] + + i := 0 + for ; i < len(handlers); i++ { + if handlers[i].callback == nil { + break + } + } + + if i == len(handlers) { + handlers = append(handlers, h) + } else { + handlers[i] = h + } + + el.handlers[t] = handlers + + return i +} + +// UnregisterHandler unregisters the handler for the given event type with the given id. +func (el *EventLoop) UnregisterHandler(eventType any, id int) { + el.mut.Lock() + defer el.mut.Unlock() + t := reflect.TypeOf(eventType) + el.handlers[t][id].callback = nil } // AddEvent adds an event to the event queue. func (el *EventLoop) AddEvent(event any) { if event != nil { el.eventQ.push(event) + el.processEvent(event, true) } } +// Context returns the context associated with the event loop. +// Usually, this context will be the one passed to Run. +// However, if Tick is used instead of Run, Context will return +// the last context that was passed to Tick. +// If neither Run nor Tick have been called, +// Context returns context.Background. +func (el *EventLoop) Context() context.Context { + el.mut.Lock() + defer el.mut.Unlock() + + return el.ctx +} + +func (el *EventLoop) setContext(ctx context.Context) { + el.mut.Lock() + defer el.mut.Unlock() + + el.ctx = ctx +} + // Run runs the event loop. A context object can be provided to stop the event loop. func (el *EventLoop) Run(ctx context.Context) { + el.setContext(ctx) + loop: for { event, ok := el.eventQ.pop() @@ -82,19 +176,21 @@ loop: el.startTicker(ctx, e.tickerID) continue } - el.processEvent(event) + el.processEvent(event, false) } // HACK: when we get cancelled, we will handle the events that were in the queue at that time before quitting. l := el.eventQ.len() for i := 0; i < l; i++ { event, _ := el.eventQ.pop() - el.processEvent(event) + el.processEvent(event, false) } } // Tick processes a single event. Returns true if an event was handled. -func (el *EventLoop) Tick() bool { +func (el *EventLoop) Tick(ctx context.Context) bool { + el.setContext(ctx) + event, ok := el.eventQ.pop() if !ok { return false @@ -103,41 +199,75 @@ func (el *EventLoop) Tick() bool { if e, ok := event.(startTickerEvent); ok { el.startTicker(context.Background(), e.tickerID) } else { - el.processEvent(event) + el.processEvent(event, false) } return true } +var handlerListPool = gpool.New(func() []EventHandler { return make([]EventHandler, 0, 10) }) + // processEvent dispatches the event to the correct handler. -func (el *EventLoop) processEvent(event any) { +func (el *EventLoop) processEvent(event any, async bool) { t := reflect.TypeOf(event) - defer el.dispatchDelayedEvents(t) + + if !async { + defer el.dispatchDelayedEvents(t) + } if f, ok := event.(func()); ok { f() return } - // run observers - for _, observer := range el.observers[t] { - observer(event) + // Must copy handlers to a list so that they can be executed after unlocking the mutex. + // Use a pool to reduce memory allocations. + priorityList := handlerListPool.Get() + handlerList := handlerListPool.Get() + + el.mut.Lock() + for _, handler := range el.handlers[t] { + if handler.opts.async != async || handler.callback == nil { + continue + } + if handler.opts.priority { + priorityList = append(priorityList, handler.callback) + } else { + handlerList = append(handlerList, handler.callback) + } + } + el.mut.Unlock() + + for _, handler := range priorityList { + handler(event) } - if handler, ok := el.handlers[t]; ok { + priorityList = priorityList[:0] + handlerListPool.Put(priorityList) + + for _, handler := range handlerList { handler(event) } + + handlerList = handlerList[:0] + handlerListPool.Put(handlerList) } func (el *EventLoop) dispatchDelayedEvents(t reflect.Type) { + var ( + events []any + ok bool + ) + el.mut.Lock() - if delayed, ok := el.waitingEvents[t]; ok { - for _, event := range delayed { - el.AddEvent(event) - } + if events, ok = el.waitingEvents[t]; ok { delete(el.waitingEvents, t) } el.mut.Unlock() + + for _, event := range events { + el.AddEvent(event) + } } // DelayUntil allows us to delay handling of an event until after another event has happened. diff --git a/eventloop/eventloop_test.go b/eventloop/eventloop_test.go index 280b598c..d1edc381 100644 --- a/eventloop/eventloop_test.go +++ b/eventloop/eventloop_test.go @@ -156,3 +156,33 @@ func TestDelayedEvent(t *testing.T) { } } } + +func BenchmarkEventLoopWithObservers(b *testing.B) { + el := eventloop.New(100) + + for i := 0; i < 100; i++ { + el.RegisterObserver(testEvent(0), func(event any) { + if event.(testEvent) != 1 { + panic("Unexpected value observed") + } + }) + } + + for i := 0; i < b.N; i++ { + el.AddEvent(testEvent(1)) + el.Tick(context.Background()) + } +} + +func BenchmarkDelay(b *testing.B) { + el := eventloop.New(100) + + for i := 0; i < b.N; i++ { + el.DelayUntil(testEvent(0), testEvent(2)) + el.DelayUntil(testEvent(0), testEvent(3)) + el.AddEvent(testEvent(1)) + el.Tick(context.Background()) + el.Tick(context.Background()) + el.Tick(context.Background()) + } +} diff --git a/go.mod b/go.mod index fb8a9e8e..20765d49 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/mattn/go-isatty v0.0.14 github.com/mitchellh/go-homedir v1.1.0 github.com/mroth/weightedrand v0.4.1 - github.com/relab/gorums v0.7.1-0.20220307181651-94a8af8e467c + github.com/relab/gorums v0.7.1-0.20220818130557-8533cb369cd6 github.com/relab/iago v0.0.0-20220416090249-bf984205c7a8 github.com/relab/wrfs v0.0.0-20220416082020-a641cd350078 github.com/spf13/cobra v1.4.0 diff --git a/go.sum b/go.sum index 70656ed9..5c667c49 100644 --- a/go.sum +++ b/go.sum @@ -636,8 +636,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/relab/gorums v0.7.1-0.20220307181651-94a8af8e467c h1:zwhtqr8u1nKoOasB6a88N3pPrWv3Keg+dScV8ZAhXLk= -github.com/relab/gorums v0.7.1-0.20220307181651-94a8af8e467c/go.mod h1:dS1JU8uB1QgQie2bvRPeJWWmIFLPyl5IU50YfWpYVBE= +github.com/relab/gorums v0.7.1-0.20220818130557-8533cb369cd6 h1:azahqG2RhvhFvHiJ5JLhlX8+vViIVe4ZSD4VryHYvfE= +github.com/relab/gorums v0.7.1-0.20220818130557-8533cb369cd6/go.mod h1:dS1JU8uB1QgQie2bvRPeJWWmIFLPyl5IU50YfWpYVBE= github.com/relab/iago v0.0.0-20220416090249-bf984205c7a8 h1:HbeM3xsbEE0pcnKc7E9EOTeWB0hmSaEIaxvgIbJxBso= github.com/relab/iago v0.0.0-20220416090249-bf984205c7a8/go.mod h1:ADclchTQWqG3npAa68T6ueqID28bf+lNKi3j1I26Bgg= github.com/relab/wrfs v0.0.0-20220416082020-a641cd350078 h1:JN5qn8C/HZoyMAycX6z6O0SeX+09CV3w3GcVNW70OZA= diff --git a/handel/handel.go b/handel/handel.go index 7c9bb7fd..20ee5f55 100644 --- a/handel/handel.go +++ b/handel/handel.go @@ -6,7 +6,8 @@ // Enabling Handel from the CLI. // // Handel can be enabled through the `--modules` flag: -// ./hotstuff run --modules="handel" +// +// ./hotstuff run --modules="handel" // // Initialization. // @@ -26,15 +27,17 @@ package handel import ( - "errors" "math" "github.com/relab/gorums" "github.com/relab/hotstuff" "github.com/relab/hotstuff/backend" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/internal/proto/handelpb" "github.com/relab/hotstuff/internal/proto/hotstuffpb" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/synchronizer" ) func init() { @@ -43,93 +46,110 @@ func init() { // Handel implements a signature aggregation protocol. type Handel struct { - mods *modules.ConsensusCore + configuration *backend.Config + server *backend.Server + + blockChain modules.BlockChain + crypto modules.Crypto + eventLoop *eventloop.EventLoop + logger logging.Logger + opts *modules.Options + synchronizer modules.Synchronizer + nodes map[hotstuff.ID]*handelpb.Node maxLevel int sessions map[hotstuff.Hash]*session + initDone bool } // New returns a new instance of the Handel module. func New() modules.Handel { return &Handel{ - nodes: make(map[hotstuff.ID]*handelpb.Node), + nodes: make(map[hotstuff.ID]*handelpb.Node), + sessions: make(map[hotstuff.Hash]*session), } } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (h *Handel) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - h.mods = mods - opts.SetShouldUseHandel() - - // the rest of the setup is deferred to the Init method. - // FIXME: it could be possible to handle the Init stuff automatically - // if the Configuration module were to send an event upon connecting. -} +// InitModule initializes the Handel module. +func (h *Handel) InitModule(mods *modules.Core) { + mods.GetAll( + &h.configuration, + &h.server, -// Init initializes the Handel module. -func (h *Handel) Init() error { - h.mods.Logger().Info("Handel: Initializing") + &h.blockChain, + &h.crypto, + &h.eventLoop, + &h.logger, + &h.opts, + &h.synchronizer, + ) - h.sessions = make(map[hotstuff.Hash]*session) - - var cfg *backend.Config - var srv *backend.Server - - if !h.mods.GetModuleByType(&srv) { - return errors.New("could not get gorums server") - } - if !h.mods.GetModuleByType(&cfg) { - return errors.New("could not get gorums configuration") - } + h.opts.SetShouldUseHandel() - handelpb.RegisterHandelServer(srv.GetGorumsServer(), serviceImpl{h}) - handelCfg := handelpb.ConfigurationFromRaw(cfg.GetRawConfiguration(), nil) - - for _, n := range handelCfg.Nodes() { - h.nodes[hotstuff.ID(n.ID())] = n - } - - h.maxLevel = int(math.Ceil(math.Log2(float64(h.mods.Configuration().Len())))) + h.eventLoop.RegisterObserver(backend.ConnectedEvent{}, func(_ any) { + h.postInit() + }) - h.mods.EventLoop().RegisterHandler(contribution{}, func(event any) { + h.eventLoop.RegisterHandler(contribution{}, func(event any) { c := event.(contribution) if s, ok := h.sessions[c.hash]; ok { s.handleContribution(c) } else if !c.deferred { c.deferred = true - h.mods.EventLoop().DelayUntil(hotstuff.ProposeMsg{}, c) + h.eventLoop.DelayUntil(hotstuff.ProposeMsg{}, c) } }) - h.mods.EventLoop().RegisterHandler(disseminateEvent{}, func(e any) { + h.eventLoop.RegisterHandler(sessionDoneEvent{}, func(event any) { + e := event.(sessionDoneEvent) + delete(h.sessions, e.hash) + }) +} + +func (h *Handel) postInit() { + h.logger.Info("Handel: Initializing") + + h.maxLevel = int(math.Ceil(math.Log2(float64(h.configuration.Len())))) + + handelCfg := handelpb.ConfigurationFromRaw(h.configuration.GetRawConfiguration(), nil) + for _, n := range handelCfg.Nodes() { + h.nodes[hotstuff.ID(n.ID())] = n + } + + handelpb.RegisterHandelServer(h.server.GetGorumsServer(), serviceImpl{h}) + + // now we can start handling timer events + h.eventLoop.RegisterHandler(disseminateEvent{}, func(e any) { + ctx, cancel := synchronizer.ViewContext(h.eventLoop.Context(), h.eventLoop, nil) + defer cancel() if s, ok := h.sessions[e.(disseminateEvent).sessionID]; ok { - s.sendContributions(s.h.mods.Synchronizer().ViewContext()) + s.sendContributions(ctx) } }) - h.mods.EventLoop().RegisterHandler(levelActivateEvent{}, func(e any) { + h.eventLoop.RegisterHandler(levelActivateEvent{}, func(e any) { if s, ok := h.sessions[e.(levelActivateEvent).sessionID]; ok { s.advanceLevel() } }) - h.mods.EventLoop().RegisterHandler(sessionDoneEvent{}, func(event any) { - e := event.(sessionDoneEvent) - delete(h.sessions, e.hash) - }) - - return nil + h.initDone = true } // Begin commissions the aggregation of a new signature. func (h *Handel) Begin(s hotstuff.PartialCert) { + if !h.initDone { + // wait until initialization is done + h.eventLoop.DelayUntil(backend.ConnectedEvent{}, func() { h.Begin(s) }) + return + } + // turn the single signature into a threshold signature, // this makes it easier to work with. session := h.newSession(s.BlockHash(), s.Signature()) h.sessions[s.BlockHash()] = session - go session.verifyContributions(h.mods.Synchronizer().ViewContext()) + go session.verifyContributions() } type serviceImpl struct { @@ -140,16 +160,16 @@ func (impl serviceImpl) Contribute(ctx gorums.ServerCtx, msg *handelpb.Contribut var hash hotstuff.Hash copy(hash[:], msg.GetHash()) - id, err := backend.GetPeerIDFromContext(ctx, impl.h.mods.Configuration()) + id, err := backend.GetPeerIDFromContext(ctx, impl.h.configuration) if err != nil { - impl.h.mods.Logger().Error(err) + impl.h.logger.Error(err) } sig := hotstuffpb.QuorumSignatureFromProto(msg.GetSignature()) indiv := hotstuffpb.QuorumSignatureFromProto(msg.GetIndividual()) if sig != nil && indiv != nil { - impl.h.mods.EventLoop().AddEvent(contribution{ + impl.h.eventLoop.AddEvent(contribution{ hash: hash, sender: id, level: int(msg.GetLevel()), @@ -158,7 +178,7 @@ func (impl serviceImpl) Contribute(ctx gorums.ServerCtx, msg *handelpb.Contribut verified: false, }) } else { - impl.h.mods.Logger().Warnf("contribution received with invalid signatures: %v, %v", sig, indiv) + impl.h.logger.Warnf("contribution received with invalid signatures: %v, %v", sig, indiv) } } diff --git a/handel/session.go b/handel/session.go index c3ca197d..87a22b24 100644 --- a/handel/session.go +++ b/handel/session.go @@ -1,6 +1,7 @@ package handel import ( + "bytes" "context" "encoding/binary" "math/rand" @@ -13,6 +14,8 @@ import ( "github.com/relab/hotstuff" "github.com/relab/hotstuff/internal/proto/handelpb" "github.com/relab/hotstuff/internal/proto/hotstuffpb" + "github.com/relab/hotstuff/synchronizer" + "github.com/relab/hotstuff/util/gpool" ) const ( @@ -92,11 +95,11 @@ func (h *Handel) newSession(hash hotstuff.Hash, in hotstuff.QuorumSignature) *se s := &session{ h: h, hash: hash, - seed: h.mods.Options().SharedRandomSeed() + int64(binary.LittleEndian.Uint64(hash[:])), + seed: h.opts.SharedRandomSeed() + int64(binary.LittleEndian.Uint64(hash[:])), window: window{ - window: h.mods.Configuration().Len(), - max: h.mods.Configuration().Len(), + window: h.configuration.Len(), + max: h.configuration.Len(), min: 2, increaseFactor: 2, decreaseFactor: 4, @@ -106,8 +109,8 @@ func (h *Handel) newSession(hash hotstuff.Hash, in hotstuff.QuorumSignature) *se // Get a sorted list of IDs for all replicas. // The configuration should also contain our own ID. - ids := make([]hotstuff.ID, 0, h.mods.Configuration().Len()) - for id := range h.mods.Configuration().Replicas() { + ids := make([]hotstuff.ID, 0, h.configuration.Len()) + for id := range h.configuration.Replicas() { ids = append(ids, id) } sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) @@ -116,28 +119,28 @@ func (h *Handel) newSession(hash hotstuff.Hash, in hotstuff.QuorumSignature) *se rnd := rand.New(rand.NewSource(s.seed)) rnd.Shuffle(len(ids), reflect.Swapper(ids)) - h.mods.Logger().Debugf("Handel session ids: %v", ids) + h.logger.Debugf("Handel session ids: %v", ids) - s.part = newPartitioner(h.mods.ID(), ids) + s.part = newPartitioner(h.opts.ID(), ids) s.levels = make([]level, h.maxLevel+1) for i := range s.levels { s.levels[i] = s.newLevel(i) min, max := s.part.rangeLevel(i) - h.mods.Logger().Debugf("level %d: %v", i, s.part.ids[min:max+1]) + h.logger.Debugf("level %d: %v", i, s.part.ids[min:max+1]) } - s.levels[0].individual[h.mods.ID()] = in + s.levels[0].individual[h.opts.ID()] = in s.levels[0].incoming = in s.updateOutgoing(1) - s.disseminateTimerID = h.mods.EventLoop().AddTicker(disseminationPeriod, func(_ time.Time) (event any) { + s.disseminateTimerID = h.eventLoop.AddTicker(disseminationPeriod, func(_ time.Time) (event any) { return disseminateEvent{s.hash} }) - s.levelActivateTimerID = h.mods.EventLoop().AddTicker(levelActivateInterval, func(_ time.Time) (event any) { + s.levelActivateTimerID = h.eventLoop.AddTicker(levelActivateInterval, func(_ time.Time) (event any) { return levelActivateEvent{s.hash} }) @@ -156,8 +159,8 @@ type level struct { func (s *session) newLevel(i int) level { return level{ - vp: verificationPriority(s.part.ids, s.seed, s.h.mods.ID(), i), - cp: contributionPriority(s.part.ids, s.seed, s.h.mods.ID(), i), + vp: verificationPriority(s.part.ids, s.seed, s.h.opts.ID(), i), + cp: contributionPriority(s.part.ids, s.seed, s.h.opts.ID(), i), individual: make(map[hotstuff.ID]hotstuff.QuorumSignature), } } @@ -194,7 +197,7 @@ func (s *session) score(contribution contribution) int { need := s.part.size(contribution.level) if contribution.level == s.h.maxLevel { - need = s.h.mods.Configuration().QuorumSize() + need = s.h.configuration.QuorumSize() } curBest := level.incoming @@ -254,7 +257,7 @@ func (s *session) score(contribution contribution) int { } } - s.h.mods.Logger().Debugf("level: %d, need: %d, added: %d, total: %d, score: %d", contribution.level, need, added, total, score) + s.h.logger.Debugf("level: %d, need: %d, added: %d, total: %d, score: %d", contribution.level, need, added, total, score) return score } @@ -298,7 +301,7 @@ func (s *session) insertPending(c contribution) { } level.pending[i] = c - s.h.mods.Logger().Debugf("pending contribution at level %d with score %d from sender %d", c.level, score, c.sender) + s.h.logger.Debugf("pending contribution at level %d with score %d from sender %d", c.level, score, c.sender) // notify verification goroutine select { @@ -315,7 +318,7 @@ func (s *session) updateIncoming(c contribution) { // check if there is a new individual signature if _, ok := level.individual[c.sender]; !ok { - s.h.mods.Logger().Debugf("New individual signature from %d for level %d", c.sender, c.level) + s.h.logger.Debugf("New individual signature from %d for level %d", c.sender, c.level) level.individual[c.sender] = c.signature } @@ -324,14 +327,16 @@ func (s *session) updateIncoming(c contribution) { return } - s.h.mods.Logger().Debugf("New incoming aggregate signature for level %d with length %d", c.level, c.signature.Participants().Len()) + s.h.logger.Debugf("New incoming aggregate signature for level %d with length %d", c.level, c.signature.Participants().Len()) level.incoming = c.signature if s.isLevelComplete(c.level) { level.done = true s.advanceLevel() if c.level+1 <= s.h.maxLevel { - s.sendFastPath(s.h.mods.Synchronizer().ViewContext(), c.level+1) + ctx, cancel := synchronizer.ViewContext(s.h.eventLoop.Context(), s.h.eventLoop, nil) + defer cancel() + s.sendFastPath(ctx, c.level+1) } } @@ -361,33 +366,33 @@ func (s *session) updateOutgoing(levelIndex int) { } else if prevLevel.incoming == nil { outgoing = prevLevel.outgoing } else { - outgoing, err = s.h.mods.Crypto().Combine(prevLevel.incoming, prevLevel.outgoing) + outgoing, err = s.h.crypto.Combine(prevLevel.incoming, prevLevel.outgoing) if err != nil { - s.h.mods.Logger().Errorf("Failed to combine incoming and outgoing for level %d: %v", levelIndex, err) + s.h.logger.Errorf("Failed to combine incoming and outgoing for level %d: %v", levelIndex, err) return } } if levelIndex > s.h.maxLevel { - if outgoing.Participants().Len() >= s.h.mods.Configuration().QuorumSize() { - s.h.mods.Logger().Debugf("Done with session: %.8s", s.hash) + if outgoing.Participants().Len() >= s.h.configuration.QuorumSize() { + s.h.logger.Debugf("Done with session: %.8s", s.hash) - s.h.mods.EventLoop().AddEvent(hotstuff.NewViewMsg{ + s.h.eventLoop.AddEvent(hotstuff.NewViewMsg{ SyncInfo: hotstuff.NewSyncInfo().WithQC(hotstuff.NewQuorumCert( outgoing, - s.h.mods.Synchronizer().View(), + s.h.synchronizer.View(), s.hash, )), }) - s.h.mods.EventLoop().AddEvent(sessionDoneEvent{s.hash}) + s.h.eventLoop.AddEvent(sessionDoneEvent{s.hash}) } } else { level := &s.levels[levelIndex] level.outgoing = outgoing - s.h.mods.Logger().Debugf("Updated outgoing for level %d: %v", levelIndex, outgoing.Participants()) + s.h.logger.Debugf("Updated outgoing for level %d: %v", levelIndex, outgoing.Participants()) if levelIndex <= s.h.maxLevel { s.updateOutgoing(levelIndex + 1) @@ -415,7 +420,7 @@ func (s *session) advanceLevel() { s.activeLevelIndex++ - s.h.mods.Logger().Debugf("advanced to level %d", s.activeLevelIndex) + s.h.logger.Debugf("advanced to level %d", s.activeLevelIndex) } func (s *session) sendContributions(ctx context.Context) { @@ -428,7 +433,7 @@ func (s *session) sendContributions(ctx context.Context) { } func (s *session) sendFastPath(ctx context.Context, levelIndex int) { - s.h.mods.Logger().Debug("fast path activated") + s.h.logger.Debug("fast path activated") n := s.part.size(levelIndex) if n > 10 { @@ -460,7 +465,7 @@ func (s *session) sendContributionToLevel(ctx context.Context, levelIndex int) { if node, ok := s.h.nodes[id]; ok { node.Contribute(ctx, &handelpb.Contribution{ - ID: uint32(s.h.mods.ID()), + ID: uint32(s.h.opts.ID()), Level: uint32(levelIndex), Signature: hotstuffpb.QuorumSignatureToProto(level.outgoing), Individual: hotstuffpb.QuorumSignatureToProto(s.levels[0].incoming), @@ -472,7 +477,10 @@ func (s *session) sendContributionToLevel(ctx context.Context, levelIndex int) { level.cp[id] += len(s.part.ids) } -func (s *session) verifyContributions(ctx context.Context) { +func (s *session) verifyContributions() { + ctx, cancel := synchronizer.ViewContext(s.h.eventLoop.Context(), s.h.eventLoop, nil) + defer cancel() + for ctx.Err() == nil { c, verifyIndiv, ok := s.chooseContribution() if !ok { @@ -487,9 +495,8 @@ func (s *session) verifyContributions(ctx context.Context) { s.verifyContribution(c, sig, verifyIndiv) } - s.h.mods.EventLoop().RemoveTicker(s.disseminateTimerID) - s.h.mods.EventLoop().RemoveTicker(s.levelActivateTimerID) - + s.h.eventLoop.RemoveTicker(s.disseminateTimerID) + s.h.eventLoop.RemoveTicker(s.levelActivateTimerID) } // chooseContribution chooses the next contribution to verify. @@ -545,7 +552,7 @@ func (s *session) chooseContribution() (cont contribution, verifyIndiv, ok bool) } best := choices[bestChoiceIndex] - s.h.mods.Logger().Debugf("Chose: %v", best.signature.Participants()) + s.h.logger.Debugf("Chose: %v", best.signature.Participants()) _, verifyIndiv = s.levels[best.level].individual[best.sender] @@ -596,22 +603,22 @@ func (s *session) improveSignature(contribution contribution) hotstuff.QuorumSig signature := contribution.signature if s.canMergeContributions(signature, level.incoming) { - new, err := s.h.mods.Crypto().Combine(signature, level.incoming) + new, err := s.h.crypto.Combine(signature, level.incoming) if err == nil { signature = new } else { - s.h.mods.Logger().Errorf("Failed to combine signatures: %v", err) + s.h.logger.Errorf("Failed to combine signatures: %v", err) } } // add any individual signature, if possible for _, indiv := range level.individual { if s.canMergeContributions(signature, indiv) { - new, err := s.h.mods.Crypto().Combine(signature, indiv) + new, err := s.h.crypto.Combine(signature, indiv) if err == nil { signature = new } else { - s.h.mods.Logger().Errorf("Failed to combine signatures: %v", err) + s.h.logger.Errorf("Failed to combine signatures: %v", err) } } } @@ -619,35 +626,48 @@ func (s *session) improveSignature(contribution contribution) hotstuff.QuorumSig return signature } +var bufferPool gpool.Pool[bytes.Buffer] + func (s *session) verifyContribution(c contribution, sig hotstuff.QuorumSignature, verifyIndiv bool) { - block, ok := s.h.mods.BlockChain().Get(s.hash) + block, ok := s.h.blockChain.Get(s.hash) if !ok { return } - s.h.mods.Logger().Debugf("verifying: %v (= %d)", sig.Participants(), sig.Participants().Len()) + buf := bufferPool.Get() + _, err := block.WriteTo(&buf) + if err != nil { + return + } + + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + s.h.logger.Debugf("verifying: %v (= %d)", sig.Participants(), sig.Participants().Len()) aggVerified := false - if s.h.mods.Crypto().Verify(sig, block.ToBytes()) { + if s.h.crypto.Verify(sig, buf.Bytes()) { aggVerified = true } else { - s.h.mods.Logger().Debug("failed to verify aggregate signature") + s.h.logger.Debug("failed to verify aggregate signature") } indivVerified := false // If the contribution is individual, we want to verify it separately if verifyIndiv { - if s.h.mods.Crypto().Verify(c.individual, block.ToBytes()) { + if s.h.crypto.Verify(c.individual, buf.Bytes()) { indivVerified = true } else { - s.h.mods.Logger().Debug("failed to verify individual signature") + s.h.logger.Debug("failed to verify individual signature") } } indivOk := (indivVerified || !verifyIndiv) if indivOk && aggVerified { - s.h.mods.EventLoop().AddEvent(contribution{ + s.h.eventLoop.AddEvent(contribution{ hash: s.hash, sender: c.sender, level: c.level, @@ -656,10 +676,10 @@ func (s *session) verifyContribution(c contribution, sig hotstuff.QuorumSignatur verified: true, }) - s.h.mods.Logger().Debug("window increased") + s.h.logger.Debug("window increased") s.window.increase() } else { - s.h.mods.Logger().Debugf("window decreased (indiv: %v, agg: %v)", indivOk, aggVerified) + s.h.logger.Debugf("window decreased (indiv: %v, agg: %v)", indivOk, aggVerified) s.window.decrease() } } diff --git a/internal/cli/run.go b/internal/cli/run.go index 93d13256..7c10aefc 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -18,7 +18,7 @@ import ( "github.com/relab/hotstuff/internal/proto/orchestrationpb" "github.com/relab/hotstuff/internal/protostream" "github.com/relab/hotstuff/logging" - "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/metrics" "github.com/relab/iago" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -230,7 +230,7 @@ func parseByzantine() (map[string]int, error) { return strategies, nil } -func localWorker(globalOutput string, metrics []string, interval time.Duration) (worker orchestration.RemoteWorker, wait func()) { +func localWorker(globalOutput string, enableMetrics []string, interval time.Duration) (worker orchestration.RemoteWorker, wait func()) { // set up an output dir output := "" if globalOutput != "" { @@ -247,7 +247,7 @@ func localWorker(globalOutput string, metrics []string, interval time.Duration) controllerPipe, workerPipe := net.Pipe() c := make(chan struct{}) go func() { - var logger modules.MetricsLogger + var logger metrics.Logger if output != "" { f, err := os.OpenFile(filepath.Join(output, "measurements.json"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) checkf("failed to create output file: %v", err) @@ -256,18 +256,18 @@ func localWorker(globalOutput string, metrics []string, interval time.Duration) wr := bufio.NewWriter(f) defer func() { checkf("failed to flush writer: %v", wr.Flush()) }() - logger, err = modules.NewJSONLogger(wr) + logger, err = metrics.NewJSONLogger(wr) checkf("failed to create JSON logger: %v", err) defer func() { checkf("failed to close logger: %v", logger.Close()) }() } else { - logger = modules.NopLogger() + logger = metrics.NopLogger() } worker := orchestration.NewWorker( protostream.NewWriter(workerPipe), protostream.NewReader(workerPipe), logger, - metrics, + enableMetrics, interval, ) diff --git a/internal/cli/worker.go b/internal/cli/worker.go index 9db8be05..98bc8d1d 100644 --- a/internal/cli/worker.go +++ b/internal/cli/worker.go @@ -9,7 +9,7 @@ import ( "github.com/relab/hotstuff/internal/orchestration" "github.com/relab/hotstuff/internal/profiling" "github.com/relab/hotstuff/internal/protostream" - "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/metrics" "github.com/spf13/cobra" ) @@ -20,7 +20,7 @@ var ( trace string fgprofProfile string - metrics []string + enableMetrics []string measurementInterval time.Duration ) @@ -54,7 +54,7 @@ func init() { workerCmd.Flags().StringVar(&trace, "trace", "", "Path to store a trace") workerCmd.Flags().StringVar(&fgprofProfile, "fgprof-profile", "", "Path to store a fgprof profile") - workerCmd.Flags().StringSliceVar(&metrics, "metrics", nil, "the metrics to enable") + workerCmd.Flags().StringSliceVar(&enableMetrics, "metrics", nil, "the metrics to enable") workerCmd.Flags().DurationVar(&measurementInterval, "measurement-interval", 0, "the interval between measurements") } @@ -66,12 +66,12 @@ func runWorker() { checkf("failed to stop profilers: %v", err) }() - metricsLogger := modules.NopLogger() + metricsLogger := metrics.NopLogger() if dataPath != "" { f, err := os.OpenFile(dataPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) checkf("failed to create data path: %v", err) writer := bufio.NewWriter(f) - metricsLogger, err = modules.NewJSONLogger(writer) + metricsLogger, err = metrics.NewJSONLogger(writer) defer func() { err = metricsLogger.Close() checkf("failed to close metrics logger: %v", err) @@ -82,7 +82,7 @@ func runWorker() { }() } - worker := orchestration.NewWorker(protostream.NewWriter(os.Stdout), protostream.NewReader(os.Stdin), metricsLogger, metrics, measurementInterval) + worker := orchestration.NewWorker(protostream.NewWriter(os.Stdout), protostream.NewReader(os.Stdin), metricsLogger, enableMetrics, measurementInterval) err = worker.Run() if err != nil { log.Println(err) diff --git a/internal/mocks/forkhandler_mock.go b/internal/mocks/forkhandler_mock.go new file mode 100644 index 00000000..5bb16802 --- /dev/null +++ b/internal/mocks/forkhandler_mock.go @@ -0,0 +1,47 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/relab/hotstuff/modules (interfaces: ForkHandler) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + hotstuff "github.com/relab/hotstuff" +) + +// MockForkHandler is a mock of ForkHandler interface. +type MockForkHandler struct { + ctrl *gomock.Controller + recorder *MockForkHandlerMockRecorder +} + +// MockForkHandlerMockRecorder is the mock recorder for MockForkHandler. +type MockForkHandlerMockRecorder struct { + mock *MockForkHandler +} + +// NewMockForkHandler creates a new mock instance. +func NewMockForkHandler(ctrl *gomock.Controller) *MockForkHandler { + mock := &MockForkHandler{ctrl: ctrl} + mock.recorder = &MockForkHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockForkHandler) EXPECT() *MockForkHandlerMockRecorder { + return m.recorder +} + +// Fork mocks base method. +func (m *MockForkHandler) Fork(arg0 hotstuff.Command) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Fork", arg0) +} + +// Fork indicates an expected call of Fork. +func (mr *MockForkHandlerMockRecorder) Fork(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fork", reflect.TypeOf((*MockForkHandler)(nil).Fork), arg0) +} diff --git a/internal/mocks/synchronizer_mock.go b/internal/mocks/synchronizer_mock.go index d23e1f5e..beb1ff3a 100644 --- a/internal/mocks/synchronizer_mock.go +++ b/internal/mocks/synchronizer_mock.go @@ -87,18 +87,6 @@ func (mr *MockSynchronizerMockRecorder) Start(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockSynchronizer)(nil).Start), arg0) } -// UpdateHighQC mocks base method. -func (m *MockSynchronizer) UpdateHighQC(arg0 hotstuff.QuorumCert) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateHighQC", arg0) -} - -// UpdateHighQC indicates an expected call of UpdateHighQC. -func (mr *MockSynchronizerMockRecorder) UpdateHighQC(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighQC", reflect.TypeOf((*MockSynchronizer)(nil).UpdateHighQC), arg0) -} - // View mocks base method. func (m *MockSynchronizer) View() hotstuff.View { m.ctrl.T.Helper() diff --git a/internal/orchestration/orchestration_test.go b/internal/orchestration/orchestration_test.go index 61475460..27236ae8 100644 --- a/internal/orchestration/orchestration_test.go +++ b/internal/orchestration/orchestration_test.go @@ -17,7 +17,7 @@ import ( "github.com/relab/hotstuff/internal/proto/orchestrationpb" "github.com/relab/hotstuff/internal/protostream" "github.com/relab/hotstuff/logging" - "github.com/relab/hotstuff/modules" + "github.com/relab/hotstuff/metrics" "github.com/relab/iago/iagotest" "google.golang.org/protobuf/types/known/durationpb" ) @@ -27,7 +27,7 @@ func TestOrchestration(t *testing.T) { controllerStream, workerStream := net.Pipe() workerProxy := orchestration.NewRemoteWorker(protostream.NewWriter(controllerStream), protostream.NewReader(controllerStream)) - worker := orchestration.NewWorker(protostream.NewWriter(workerStream), protostream.NewReader(workerStream), modules.NopLogger(), nil, 0) + worker := orchestration.NewWorker(protostream.NewWriter(workerStream), protostream.NewReader(workerStream), metrics.NopLogger(), nil, 0) experiment := &orchestration.Experiment{ Logger: logging.New("ctrl"), @@ -51,7 +51,7 @@ func TestOrchestration(t *testing.T) { LeaderRotation: "round-robin", Modules: mods, }, - Duration: 1 * time.Second, + Duration: 5 * time.Second, Hosts: map[string]orchestration.RemoteWorker{"127.0.0.1": workerProxy}, } @@ -113,7 +113,7 @@ func TestDeployment(t *testing.T) { Crypto: "ecdsa", LeaderRotation: "round-robin", }, - Duration: 1 * time.Second, + Duration: 10 * time.Second, Hosts: make(map[string]orchestration.RemoteWorker), } diff --git a/internal/orchestration/worker.go b/internal/orchestration/worker.go index 3197bd95..2b35322b 100644 --- a/internal/orchestration/worker.go +++ b/internal/orchestration/worker.go @@ -19,7 +19,7 @@ import ( "github.com/relab/hotstuff/consensus/byzantine" "github.com/relab/hotstuff/crypto" "github.com/relab/hotstuff/crypto/keygen" - "github.com/relab/hotstuff/handel" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/internal/proto/orchestrationpb" "github.com/relab/hotstuff/internal/protostream" "github.com/relab/hotstuff/logging" @@ -48,7 +48,7 @@ type Worker struct { send *protostream.Writer recv *protostream.Reader - metricsLogger modules.MetricsLogger + metricsLogger metrics.Logger metrics []string measurementInterval time.Duration @@ -93,7 +93,7 @@ func (w *Worker) Run() error { } // NewWorker returns a new worker. -func NewWorker(send *protostream.Writer, recv *protostream.Reader, dl modules.MetricsLogger, metrics []string, measurementInterval time.Duration) Worker { +func NewWorker(send *protostream.Writer, recv *protostream.Reader, dl metrics.Logger, metrics []string, measurementInterval time.Duration) Worker { return Worker{ send: send, recv: recv, @@ -163,7 +163,7 @@ func (w *Worker) createReplica(opts *orchestrationpb.ReplicaOpts) (*replica.Repl rootCAs.AppendCertsFromPEM(opts.GetCertificateAuthority()) } // prepare modules - builder := modules.NewConsensusBuilder(hotstuff.ID(opts.GetID()), privKey) + builder := modules.NewBuilder(hotstuff.ID(opts.GetID()), privKey) consensusRules, ok := modules.GetModule[consensus.Rules](opts.GetConsensus()) if !ok { @@ -195,7 +195,8 @@ func (w *Worker) createReplica(opts *orchestrationpb.ReplicaOpts) (*replica.Repl float64(opts.GetTimeoutMultiplier()), )) - builder.Register( + builder.Add( + eventloop.New(1000), consensus.New(consensusRules), consensus.NewVotingMachine(), crypto.NewCache(cryptoImpl, 100), // TODO: consider making this configurable @@ -206,12 +207,12 @@ func (w *Worker) createReplica(opts *orchestrationpb.ReplicaOpts) (*replica.Repl logging.New("hs"+strconv.Itoa(int(opts.GetID()))), ) - builder.OptionsBuilder().SetSharedRandomSeed(opts.GetSharedSeed()) + builder.Options().SetSharedRandomSeed(opts.GetSharedSeed()) if w.measurementInterval > 0 { replicaMetrics := metrics.GetReplicaMetrics(w.metrics...) - builder.Register(replicaMetrics...) - builder.Register(metrics.NewTicker(w.measurementInterval)) + builder.Add(replicaMetrics...) + builder.Add(metrics.NewTicker(w.measurementInterval)) } for _, n := range opts.GetModules() { @@ -219,7 +220,7 @@ func (w *Worker) createReplica(opts *orchestrationpb.ReplicaOpts) (*replica.Repl if !ok { return nil, fmt.Errorf("no module named '%s'", n) } - builder.Register(m) + builder.Add(m) } c := replica.Config{ @@ -253,15 +254,6 @@ func (w *Worker) startReplicas(req *orchestrationpb.StartReplicaRequest) (*orche return nil, err } - // start Handel if enabled - var h *handel.Handel - if replica.Modules().GetModuleByType(&h) { - err = h.Init() - if err != nil { - return nil, err - } - } - defer func(id uint32) { w.metricsLogger.Log(&types.StartEvent{Event: types.NewReplicaEvent(id, time.Now())}) replica.Start() @@ -308,16 +300,17 @@ func (w *Worker) startClients(req *orchestrationpb.StartClientRequest) (*orchest RateStepInterval: opts.GetRateStepInterval().AsDuration(), Timeout: opts.GetTimeout().AsDuration(), } - mods := modules.NewCoreBuilder(hotstuff.ID(opts.GetID())) + mods := modules.NewBuilder(hotstuff.ID(opts.GetID()), nil) + mods.Add(eventloop.New(1000)) if w.measurementInterval > 0 { clientMetrics := metrics.GetClientMetrics(w.metrics...) - mods.Register(clientMetrics...) - mods.Register(metrics.NewTicker(w.measurementInterval)) + mods.Add(clientMetrics...) + mods.Add(metrics.NewTicker(w.measurementInterval)) } - mods.Register(w.metricsLogger) - mods.Register(logging.New("cli" + strconv.Itoa(int(opts.GetID())))) + mods.Add(w.metricsLogger) + mods.Add(logging.New("cli" + strconv.Itoa(int(opts.GetID())))) cli := client.New(c, mods) cfg, err := getConfiguration(req.GetConfiguration(), true) if err != nil { diff --git a/internal/proto/hotstuffpb/convert.go b/internal/proto/hotstuffpb/convert.go index 0f993740..b2f1ed19 100644 --- a/internal/proto/hotstuffpb/convert.go +++ b/internal/proto/hotstuffpb/convert.go @@ -2,6 +2,8 @@ package hotstuffpb import ( "math/big" + "reflect" + "unsafe" "github.com/relab/hotstuff" "github.com/relab/hotstuff/crypto" @@ -68,9 +70,10 @@ func PartialCertToProto(cert hotstuff.PartialCert) *PartialCert { // PartialCertFromProto converts a hotstuffpb.PartialCert to an ecdsa.PartialCert. func PartialCertFromProto(cert *PartialCert) hotstuff.PartialCert { - var h hotstuff.Hash - copy(h[:], cert.GetHash()) - return hotstuff.NewPartialCert(QuorumSignatureFromProto(cert.GetSig()), h) + return hotstuff.NewPartialCert( + QuorumSignatureFromProto(cert.GetSig()), + convertHash(cert.GetHash()), + ) } // QuorumCertToProto converts a consensus.QuorumCert to a hotstuffpb.QuorumCert. @@ -85,9 +88,11 @@ func QuorumCertToProto(qc hotstuff.QuorumCert) *QuorumCert { // QuorumCertFromProto converts a hotstuffpb.QuorumCert to an ecdsa.QuorumCert. func QuorumCertFromProto(qc *QuorumCert) hotstuff.QuorumCert { - var h hotstuff.Hash - copy(h[:], qc.GetHash()) - return hotstuff.NewQuorumCert(QuorumSignatureFromProto(qc.GetSig()), hotstuff.View(qc.GetView()), h) + return hotstuff.NewQuorumCert( + QuorumSignatureFromProto(qc.GetSig()), + hotstuff.View(qc.GetView()), + convertHash(qc.GetHash()), + ) } // ProposalToProto converts a ProposeMsg to a protobuf message. @@ -116,7 +121,7 @@ func BlockToProto(block *hotstuff.Block) *Block { parentHash := block.Parent() return &Block{ Parent: parentHash[:], - Command: []byte(block.Command()), + Command: unsafeStringToBytes(block.Command()), QC: QuorumCertToProto(block.QuorumCert()), View: uint64(block.View()), Proposer: uint32(block.Proposer()), @@ -125,12 +130,10 @@ func BlockToProto(block *hotstuff.Block) *Block { // BlockFromProto converts a hotstuffpb.Block to a consensus.Block. func BlockFromProto(block *Block) *hotstuff.Block { - var p hotstuff.Hash - copy(p[:], block.GetParent()) return hotstuff.NewBlock( - p, + convertHash(block.GetParent()), QuorumCertFromProto(block.GetQC()), - hotstuff.Command(block.GetCommand()), + unsafeBytesToString(block.GetCommand()), hotstuff.View(block.GetView()), hotstuff.ID(block.GetProposer()), ) @@ -222,3 +225,27 @@ func SyncInfoToProto(syncInfo hotstuff.SyncInfo) *SyncInfo { } return m } + +func convertHash(b []byte) (h hotstuff.Hash) { + if len(b) < len(h) { + copy(h[:], b) + } else { + h = *(*hotstuff.Hash)(b) + } + return h +} + +func unsafeStringToBytes(s string) []byte { + if s == "" { + return []byte{} + } + const max = 0x7fff0000 + if len(s) > max { + panic("string too long") + } + return (*[max]byte)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data))[:len(s):len(s)] +} + +func unsafeBytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} diff --git a/internal/proto/hotstuffpb/convert_test.go b/internal/proto/hotstuffpb/convert_test.go index 201d47ba..68121cf4 100644 --- a/internal/proto/hotstuffpb/convert_test.go +++ b/internal/proto/hotstuffpb/convert_test.go @@ -2,9 +2,12 @@ package hotstuffpb import ( "bytes" + "crypto/rand" + "io" + "testing" + "github.com/relab/hotstuff" "github.com/relab/hotstuff/modules" - "testing" "github.com/golang/mock/gomock" "github.com/relab/hotstuff/crypto" @@ -16,10 +19,12 @@ func TestConvertPartialCert(t *testing.T) { ctrl := gomock.NewController(t) key := testutil.GenerateECDSAKey(t) - builder := modules.NewConsensusBuilder(1, key) + builder := modules.NewBuilder(1, key) testutil.TestModules(t, ctrl, 1, key, &builder) hs := builder.Build() - signer := hs.Crypto() + + var signer modules.Crypto + hs.Get(&signer) want, err := signer.CreatePartialCert(hotstuff.GetGenesis()) if err != nil { @@ -44,7 +49,10 @@ func TestConvertQuorumCert(t *testing.T) { signatures := testutil.CreatePCs(t, b1, hl.Signers()) - want, err := hl[0].Crypto().CreateQuorumCert(b1, signatures) + var signer modules.Crypto + hl[0].Get(&signer) + + want, err := signer.CreateQuorumCert(b1, signatures) if err != nil { t.Fatal(err) } @@ -73,7 +81,7 @@ func TestConvertTimeoutCertBLS12(t *testing.T) { builders := testutil.CreateBuilders(t, ctrl, 4, testutil.GenerateKeys(t, 4, testutil.GenerateBLS12Key)...) for i := range builders { - builders[i].Register(crypto.New(bls12.New())) + builders[i].Add(crypto.New(bls12.New())) } hl := builders.Build() @@ -82,7 +90,42 @@ func TestConvertTimeoutCertBLS12(t *testing.T) { pb := TimeoutCertToProto(tc1) tc2 := TimeoutCertFromProto(pb) - if !hl[0].Crypto().VerifyTimeoutCert(tc2) { + var signer modules.Crypto + hl[0].Get(&signer) + + if !signer.VerifyTimeoutCert(tc2) { t.Fatal("Failed to verify timeout cert") } } + +func BenchmarkConvertHash(b *testing.B) { + s := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, s) + if err != nil { + b.Fatal(err) + } + + var h hotstuff.Hash + + for i := 0; i < b.N; i++ { + h = convertHash(s) + } + + _ = h +} + +func BenchmarkCopyHash(b *testing.B) { + s := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, s) + if err != nil { + b.Fatal(err) + } + + var h hotstuff.Hash + + for i := 0; i < b.N; i++ { + copy(h[:], s) + } + + _ = h +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index d85b04f0..be5b4d2c 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -9,6 +9,7 @@ import ( "time" "github.com/relab/hotstuff/consensus" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/modules" "github.com/golang/mock/gomock" @@ -26,7 +27,7 @@ import ( ) // TestModules registers default modules for testing to the given builder. -func TestModules(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, privkey hotstuff.PrivateKey, builder *modules.ConsensusBuilder) { +func TestModules(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, privkey hotstuff.PrivateKey, builder *modules.Builder) { t.Helper() acceptor := mocks.NewMockAcceptor(ctrl) @@ -36,6 +37,8 @@ func TestModules(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, privkey executor := mocks.NewMockExecutor(ctrl) executor.EXPECT().Exec(gomock.AssignableToTypeOf(hotstuff.Command(""))).AnyTimes() + forkHandler := mocks.NewMockForkHandler(ctrl) + commandQ := mocks.NewMockCommandQueue(ctrl) commandQ.EXPECT().Get(gomock.Any()).AnyTimes().Return(hotstuff.Command("foo"), true) @@ -45,15 +48,12 @@ func TestModules(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, privkey config.EXPECT().Len().AnyTimes().Return(1) config.EXPECT().QuorumSize().AnyTimes().Return(3) - replica := CreateMockReplica(t, ctrl, id, privkey.Public()) - ConfigAddReplica(t, config, replica) - config.EXPECT().Replicas().AnyTimes().Return((map[hotstuff.ID]modules.Replica{1: replica})) - synchronizer := mocks.NewMockSynchronizer(ctrl) synchronizer.EXPECT().Start(gomock.Any()).AnyTimes() synchronizer.EXPECT().ViewContext().AnyTimes().Return(context.Background()) - builder.Register( + builder.Add( + eventloop.New(100), logging.New(fmt.Sprintf("hs%d", id)), blockchain.New(), mocks.NewMockConsensus(ctrl), @@ -63,16 +63,17 @@ func TestModules(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, privkey config, signer, acceptor, - executor, + modules.ExtendedExecutor(executor), commandQ, + modules.ExtendedForkHandler(forkHandler), ) } // BuilderList is a helper type to perform actions on a set of builders. -type BuilderList []*modules.ConsensusBuilder +type BuilderList []*modules.Builder // HotStuffList is a helper type to perform actions on a set of HotStuff instances. -type HotStuffList []*modules.ConsensusCore +type HotStuffList []*modules.Core // Build calls Build() for all of the builders. func (bl BuilderList) Build() HotStuffList { @@ -87,7 +88,7 @@ func (bl BuilderList) Build() HotStuffList { func (hl HotStuffList) Signers() (signers []modules.Crypto) { signers = make([]modules.Crypto, len(hl)) for i, hs := range hl { - signers[i] = hs.Crypto() + hs.Get(&signers[i]) } return signers } @@ -96,7 +97,7 @@ func (hl HotStuffList) Signers() (signers []modules.Crypto) { func (hl HotStuffList) Verifiers() (verifiers []modules.Crypto) { verifiers = make([]modules.Crypto, len(hl)) for i, hs := range hl { - verifiers[i] = hs.Crypto() + hs.Get(&verifiers[i]) } return verifiers } @@ -105,7 +106,9 @@ func (hl HotStuffList) Verifiers() (verifiers []modules.Crypto) { func (hl HotStuffList) Keys() (keys []hotstuff.PrivateKey) { keys = make([]hotstuff.PrivateKey, len(hl)) for i, hs := range hl { - keys[i] = hs.PrivateKey() + var opts *modules.Options + hs.Get(&opts) + keys[i] = opts.PrivateKey() } return keys } @@ -114,7 +117,7 @@ func (hl HotStuffList) Keys() (keys []hotstuff.PrivateKey) { func CreateBuilders(t *testing.T, ctrl *gomock.Controller, n int, keys ...hotstuff.PrivateKey) (builders BuilderList) { t.Helper() network := twins.NewSimpleNetwork() - builders = make([]*modules.ConsensusBuilder, n) + builders = make([]*modules.Builder, n) for i := 0; i < n; i++ { id := hotstuff.ID(i + 1) var key hotstuff.PrivateKey @@ -125,63 +128,14 @@ func CreateBuilders(t *testing.T, ctrl *gomock.Controller, n int, keys ...hotstu } builder := network.GetNodeBuilder(twins.NodeID{ReplicaID: id, NetworkID: uint32(id)}, key) + builder.Add(network.NewConfiguration()) TestModules(t, ctrl, id, key, &builder) - builder.Register(network.NewConfiguration()) + builder.Add(network.NewConfiguration()) builders[i] = &builder } return builders } -// CreateMockConfigurationWithReplicas creates a configuration with n replicas. -func CreateMockConfigurationWithReplicas(t *testing.T, ctrl *gomock.Controller, n int, keys ...hotstuff.PrivateKey) (*mocks.MockConfiguration, []*mocks.MockReplica) { - t.Helper() - cfg := mocks.NewMockConfiguration(ctrl) - replicas := make([]*mocks.MockReplica, n) - if len(keys) == 0 { - keys = make([]hotstuff.PrivateKey, 0, n) - } - for i := 0; i < n; i++ { - if len(keys) <= i { - keys = append(keys, GenerateECDSAKey(t)) - } - replicas[i] = CreateMockReplica(t, ctrl, hotstuff.ID(i+1), keys[i].Public()) - ConfigAddReplica(t, cfg, replicas[i]) - } - cfg.EXPECT().Len().AnyTimes().Return(len(replicas)) - cfg.EXPECT().QuorumSize().AnyTimes().Return(hotstuff.QuorumSize(len(replicas))) - return cfg, replicas -} - -// CreateMockReplica returns a mock of a consensus.Replica. -func CreateMockReplica(t *testing.T, ctrl *gomock.Controller, id hotstuff.ID, key hotstuff.PublicKey) *mocks.MockReplica { - t.Helper() - - replica := mocks.NewMockReplica(ctrl) - replica. - EXPECT(). - ID(). - AnyTimes(). - Return(id) - replica. - EXPECT(). - PublicKey(). - AnyTimes(). - Return(key) - - return replica -} - -// ConfigAddReplica adds a mock replica to a mock configuration. -func ConfigAddReplica(t *testing.T, cfg *mocks.MockConfiguration, replica *mocks.MockReplica) { - t.Helper() - - cfg. - EXPECT(). - Replica(replica.ID()). - AnyTimes(). - Return(replica, true) -} - // CreateTCPListener creates a net.Listener on a random port. func CreateTCPListener(t *testing.T) net.Listener { t.Helper() diff --git a/leaderrotation/carousel.go b/leaderrotation/carousel.go index 3ee32dd2..68026fb4 100644 --- a/leaderrotation/carousel.go +++ b/leaderrotation/carousel.go @@ -4,6 +4,7 @@ import ( "math/rand" "github.com/relab/hotstuff" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "golang.org/x/exp/slices" ) @@ -13,31 +14,41 @@ func init() { } type carousel struct { - mods *modules.ConsensusCore + blockChain modules.BlockChain + configuration modules.Configuration + consensus modules.Consensus + opts *modules.Options + logger logging.Logger } -func (c *carousel) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - c.mods = mods +func (c *carousel) InitModule(mods *modules.Core) { + mods.GetAll( + &c.blockChain, + &c.configuration, + &c.consensus, + &c.opts, + &c.logger, + ) } func (c carousel) GetLeader(round hotstuff.View) hotstuff.ID { - commitHead := c.mods.Consensus().CommittedBlock() + commitHead := c.consensus.CommittedBlock() if commitHead.QuorumCert().Signature() == nil { - c.mods.Logger().Debug("in startup; using round-robin") - return chooseRoundRobin(round, c.mods.Configuration().Len()) + c.logger.Debug("in startup; using round-robin") + return chooseRoundRobin(round, c.configuration.Len()) } - if commitHead.View() != round-hotstuff.View(c.mods.Consensus().ChainLength()) { - c.mods.Logger().Debugf("fallback to round-robin (view=%d, commitHead=%d)", round, commitHead.View()) - return chooseRoundRobin(round, c.mods.Configuration().Len()) + if commitHead.View() != round-hotstuff.View(c.consensus.ChainLength()) { + c.logger.Debugf("fallback to round-robin (view=%d, commitHead=%d)", round, commitHead.View()) + return chooseRoundRobin(round, c.configuration.Len()) } - c.mods.Logger().Debug("proceeding with carousel") + c.logger.Debug("proceeding with carousel") var ( block = commitHead - f = hotstuff.NumFaulty(c.mods.Configuration().Len()) + f = hotstuff.NumFaulty(c.configuration.Len()) i = 0 lastAuthors = hotstuff.NewIDSet() ok = true @@ -45,11 +56,11 @@ func (c carousel) GetLeader(round hotstuff.View) hotstuff.ID { for ok && i < f && block != hotstuff.GetGenesis() { lastAuthors.Add(block.Proposer()) - block, ok = c.mods.BlockChain().Get(block.Parent()) + block, ok = c.blockChain.Get(block.Parent()) i++ } - candidates := make([]hotstuff.ID, 0, c.mods.Configuration().Len()-f) + candidates := make([]hotstuff.ID, 0, c.configuration.Len()-f) commitHead.QuorumCert().Signature().Participants().ForEach(func(id hotstuff.ID) { if !lastAuthors.Contains(id) { @@ -58,11 +69,11 @@ func (c carousel) GetLeader(round hotstuff.View) hotstuff.ID { }) slices.Sort(candidates) - seed := c.mods.Options().SharedRandomSeed() + int64(round) + seed := c.opts.SharedRandomSeed() + int64(round) rnd := rand.New(rand.NewSource(seed)) leader := candidates[rnd.Int()%len(candidates)] - c.mods.Logger().Debugf("chose id %d", leader) + c.logger.Debugf("chose id %d", leader) return leader } diff --git a/leaderrotation/fixed.go b/leaderrotation/fixed.go index 93150005..59eeb8d9 100644 --- a/leaderrotation/fixed.go +++ b/leaderrotation/fixed.go @@ -15,6 +15,10 @@ type fixed struct { leader hotstuff.ID } +func (*fixed) New() fixed { + return fixed{} +} + // GetLeader returns the id of the leader in the given view func (f fixed) GetLeader(_ hotstuff.View) hotstuff.ID { return f.leader @@ -22,5 +26,5 @@ func (f fixed) GetLeader(_ hotstuff.View) hotstuff.ID { // NewFixed returns a new fixed-leader leader rotation implementation. func NewFixed(leader hotstuff.ID) modules.LeaderRotation { - return fixed{leader} + return fixed{leader: leader} } diff --git a/leaderrotation/reputation.go b/leaderrotation/reputation.go index 3a8be864..fd46b391 100644 --- a/leaderrotation/reputation.go +++ b/leaderrotation/reputation.go @@ -7,6 +7,7 @@ import ( "golang.org/x/exp/slices" "github.com/relab/hotstuff" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" ) @@ -17,30 +18,38 @@ func init() { type reputationsMap map[hotstuff.ID]float64 type repBased struct { - mods *modules.ConsensusCore + configuration modules.Configuration + consensus modules.Consensus + opts *modules.Options + logger logging.Logger prevCommitHead *hotstuff.Block reputations reputationsMap // latest reputations } -// InitModule gives the module a reference to the ConsensusCore object. +// InitModule gives the module a reference to the Core object. // It also allows the module to set module options using the OptionsBuilder -func (r *repBased) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - r.mods = mods +func (r *repBased) InitModule(mods *modules.Core) { + mods.GetAll( + &r.configuration, + &r.consensus, + &r.opts, + &r.logger, + ) } // TODO: should GetLeader be thread-safe? // GetLeader returns the id of the leader in the given view func (r *repBased) GetLeader(view hotstuff.View) hotstuff.ID { - block := r.mods.Consensus().CommittedBlock() - if block.View() > view-hotstuff.View(r.mods.Consensus().ChainLength()) { + block := r.consensus.CommittedBlock() + if block.View() > view-hotstuff.View(r.consensus.ChainLength()) { // TODO: it could be possible to lookup leaders for older views if we // store a copy of the reputations in a metadata field of each block. - r.mods.Logger().Error("looking up leaders of old views is not supported") + r.logger.Error("looking up leaders of old views is not supported") return 0 } - numReplicas := r.mods.Configuration().Len() + numReplicas := r.configuration.Len() // use round-robin for the first few views until we get a signature if block.QuorumCert().Signature() == nil { return chooseRoundRobin(view, numReplicas) @@ -75,19 +84,19 @@ func (r *repBased) GetLeader(view hotstuff.View) hotstuff.ID { r.prevCommitHead = block } - r.mods.Logger().Debug(weights) + r.logger.Debug(weights) chooser, err := wr.NewChooser(weights...) if err != nil { - r.mods.Logger().Error("weightedrand error: ", err) + r.logger.Error("weightedrand error: ", err) return 0 } - seed := r.mods.Options().SharedRandomSeed() + int64(view) + seed := r.opts.SharedRandomSeed() + int64(view) rnd := rand.New(rand.NewSource(seed)) leader := chooser.PickSource(rnd).(hotstuff.ID) - r.mods.Logger().Debugf("picked leader %d for view %d using seed %d", leader, view, seed) + r.logger.Debugf("picked leader %d for view %d using seed %d", leader, view, seed) return leader } diff --git a/leaderrotation/roundrobin.go b/leaderrotation/roundrobin.go index 437cbbad..8b2523e7 100644 --- a/leaderrotation/roundrobin.go +++ b/leaderrotation/roundrobin.go @@ -10,20 +10,18 @@ func init() { } type roundRobin struct { - mods *modules.ConsensusCore + configuration modules.Configuration } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (rr *roundRobin) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - rr.mods = mods +func (rr *roundRobin) InitModule(mods *modules.Core) { + mods.Get(&rr.configuration) } // GetLeader returns the id of the leader in the given view func (rr roundRobin) GetLeader(view hotstuff.View) hotstuff.ID { // TODO: does not support reconfiguration // assume IDs start at 1 - return chooseRoundRobin(view, rr.mods.Configuration().Len()) + return chooseRoundRobin(view, rr.configuration.Len()) } // NewRoundRobin returns a new round-robin leader rotation implementation. diff --git a/metrics/clientlatency.go b/metrics/clientlatency.go index 57fc4feb..338b3664 100644 --- a/metrics/clientlatency.go +++ b/metrics/clientlatency.go @@ -4,6 +4,8 @@ import ( "time" "github.com/relab/hotstuff/client" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/metrics/types" "github.com/relab/hotstuff/modules" ) @@ -16,24 +18,36 @@ func init() { // ClientLatency processes LatencyMeasurementEvents, and writes LatencyMeasurements to the metrics logger. type ClientLatency struct { - mods *modules.Core - wf Welford + metricsLogger Logger + opts *modules.Options + + wf Welford } // InitModule gives the module access to the other modules. func (lr *ClientLatency) InitModule(mods *modules.Core) { - lr.mods = mods + var ( + eventLoop *eventloop.EventLoop + logger logging.Logger + ) + + mods.GetAll( + &lr.metricsLogger, + &lr.opts, + &eventLoop, + &logger, + ) - lr.mods.EventLoop().RegisterHandler(client.LatencyMeasurementEvent{}, func(event any) { + eventLoop.RegisterHandler(client.LatencyMeasurementEvent{}, func(event any) { latencyEvent := event.(client.LatencyMeasurementEvent) lr.addLatency(latencyEvent.Latency) }) - lr.mods.EventLoop().RegisterObserver(types.TickEvent{}, func(event any) { + eventLoop.RegisterObserver(types.TickEvent{}, func(event any) { lr.tick(event.(types.TickEvent)) }) - lr.mods.Logger().Info("Client Latency metric enabled") + logger.Info("Client Latency metric enabled") } // AddLatency adds a latency data point to the current measurement. @@ -45,11 +59,11 @@ func (lr *ClientLatency) addLatency(latency time.Duration) { func (lr *ClientLatency) tick(tick types.TickEvent) { mean, variance, count := lr.wf.Get() event := &types.LatencyMeasurement{ - Event: types.NewClientEvent(uint32(lr.mods.ID()), time.Now()), + Event: types.NewClientEvent(uint32(lr.opts.ID()), time.Now()), Latency: mean, Variance: variance, Count: count, } - lr.mods.MetricsLogger().Log(event) + lr.metricsLogger.Log(event) lr.wf.Reset() } diff --git a/modules/datalogger.go b/metrics/datalogger.go similarity index 79% rename from modules/datalogger.go rename to metrics/datalogger.go index d1506a89..573f98d2 100644 --- a/modules/datalogger.go +++ b/metrics/datalogger.go @@ -1,30 +1,33 @@ -package modules +package metrics import ( "fmt" "io" "sync" + "github.com/relab/hotstuff/logging" + "github.com/relab/hotstuff/modules" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) -// MetricsLogger logs data in protobuf message format. -type MetricsLogger interface { +// Logger logs data in protobuf message format. +type Logger interface { Log(proto.Message) io.Closer } type jsonLogger struct { + logger logging.Logger + mut sync.Mutex - mods *Core wr io.Writer first bool } // NewJSONLogger returns a new metrics logger that logs to the specified writer. -func NewJSONLogger(wr io.Writer) (MetricsLogger, error) { +func NewJSONLogger(wr io.Writer) (Logger, error) { _, err := io.WriteString(wr, "[\n") if err != nil { return nil, fmt.Errorf("failed to write start of JSON array: %v", err) @@ -33,8 +36,8 @@ func NewJSONLogger(wr io.Writer) (MetricsLogger, error) { } // InitModule initializes the metrics logger module. -func (dl *jsonLogger) InitModule(mods *Core) { - dl.mods = mods +func (dl *jsonLogger) InitModule(mods *modules.Core) { + mods.Get(&dl.logger) } func (dl *jsonLogger) Log(msg proto.Message) { @@ -46,13 +49,13 @@ func (dl *jsonLogger) Log(msg proto.Message) { if any, ok = msg.(*anypb.Any); !ok { any, err = anypb.New(msg) if err != nil { - dl.mods.Logger().Errorf("failed to create Any message: %v", err) + dl.logger.Errorf("failed to create Any message: %v", err) return } } err = dl.write(any) if err != nil { - dl.mods.Logger().Errorf("failed to write message to log: %v", err) + dl.logger.Errorf("failed to write message to log: %v", err) } } @@ -94,6 +97,6 @@ func (nopLogger) Close() error { return nil } // NopLogger returns a metrics logger that discards any messages. // This is useful for testing and other situations where metrics logging is disabled. -func NopLogger() MetricsLogger { +func NopLogger() Logger { return nopLogger{} } diff --git a/metrics/throughput.go b/metrics/throughput.go index b182545f..4216eb22 100644 --- a/metrics/throughput.go +++ b/metrics/throughput.go @@ -5,6 +5,8 @@ import ( "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/metrics/types" "github.com/relab/hotstuff/modules" "google.golang.org/protobuf/types/known/durationpb" @@ -18,22 +20,37 @@ func init() { // Throughput measures throughput in commits per second, and commands per second. type Throughput struct { - mods *modules.Core + metricsLogger Logger + opts *modules.Options + commitCount uint64 commandCount uint64 } // InitModule gives the module access to the other modules. func (t *Throughput) InitModule(mods *modules.Core) { - t.mods = mods - t.mods.EventLoop().RegisterHandler(hotstuff.CommitEvent{}, func(event any) { + var ( + eventLoop *eventloop.EventLoop + logger logging.Logger + ) + + mods.GetAll( + &t.metricsLogger, + &t.opts, + &eventLoop, + &logger, + ) + + eventLoop.RegisterHandler(hotstuff.CommitEvent{}, func(event any) { commitEvent := event.(hotstuff.CommitEvent) t.recordCommit(commitEvent.Commands) }) - t.mods.EventLoop().RegisterObserver(types.TickEvent{}, func(event any) { + + eventLoop.RegisterObserver(types.TickEvent{}, func(event any) { t.tick(event.(types.TickEvent)) }) - t.mods.Logger().Info("Throughput metric enabled") + + logger.Info("Throughput metric enabled") } func (t *Throughput) recordCommit(commands int) { @@ -44,12 +61,12 @@ func (t *Throughput) recordCommit(commands int) { func (t *Throughput) tick(tick types.TickEvent) { now := time.Now() event := &types.ThroughputMeasurement{ - Event: types.NewReplicaEvent(uint32(t.mods.ID()), now), + Event: types.NewReplicaEvent(uint32(t.opts.ID()), now), Commits: t.commitCount, Commands: t.commandCount, Duration: durationpb.New(now.Sub(tick.LastTick)), } - t.mods.MetricsLogger().Log(event) + t.metricsLogger.Log(event) // reset count for next tick t.commandCount = 0 t.commitCount = 0 diff --git a/metrics/ticker.go b/metrics/ticker.go index d3153858..65b3f860 100644 --- a/metrics/ticker.go +++ b/metrics/ticker.go @@ -3,13 +3,13 @@ package metrics import ( "time" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/metrics/types" "github.com/relab/hotstuff/modules" ) // Ticker emits TickEvents on the metrics event loop. type Ticker struct { - mods *modules.Core tickerID int interval time.Duration lastTick time.Time @@ -22,8 +22,11 @@ func NewTicker(interval time.Duration) *Ticker { // InitModule gives the module access to the other modules. func (t *Ticker) InitModule(mods *modules.Core) { - t.mods = mods - t.tickerID = t.mods.EventLoop().AddTicker(t.interval, t.tick) + var eventLoop *eventloop.EventLoop + + mods.Get(&eventLoop) + + t.tickerID = eventLoop.AddTicker(t.interval, t.tick) } func (t *Ticker) tick(tickTime time.Time) any { diff --git a/metrics/timeouts.go b/metrics/timeouts.go index 77fd9511..24d03fa4 100644 --- a/metrics/timeouts.go +++ b/metrics/timeouts.go @@ -3,6 +3,8 @@ package metrics import ( "time" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/metrics/types" "github.com/relab/hotstuff/modules" "github.com/relab/hotstuff/synchronizer" @@ -16,22 +18,34 @@ func init() { // ViewTimeouts is a metric that measures the number of view timeouts that happen. type ViewTimeouts struct { - mods *modules.Core + metricsLogger Logger + opts *modules.Options + numViews uint64 numTimeouts uint64 } // InitModule gives the module access to the other modules. func (vt *ViewTimeouts) InitModule(mods *modules.Core) { - vt.mods = mods + var ( + eventLoop *eventloop.EventLoop + logger logging.Logger + ) + + mods.GetAll( + &vt.metricsLogger, + &vt.opts, + &eventLoop, + &logger, + ) - vt.mods.Logger().Info("ViewTimeouts metric enabled.") + logger.Info("ViewTimeouts metric enabled.") - vt.mods.EventLoop().RegisterHandler(synchronizer.ViewChangeEvent{}, func(event any) { + eventLoop.RegisterHandler(synchronizer.ViewChangeEvent{}, func(event any) { vt.viewChange(event.(synchronizer.ViewChangeEvent)) }) - vt.mods.EventLoop().RegisterObserver(types.TickEvent{}, func(event any) { + eventLoop.RegisterObserver(types.TickEvent{}, func(event any) { vt.tick(event.(types.TickEvent)) }) } @@ -44,8 +58,8 @@ func (vt *ViewTimeouts) viewChange(event synchronizer.ViewChangeEvent) { } func (vt *ViewTimeouts) tick(event types.TickEvent) { - vt.mods.MetricsLogger().Log(&types.ViewTimeouts{ - Event: types.NewReplicaEvent(uint32(vt.mods.ID()), time.Now()), + vt.metricsLogger.Log(&types.ViewTimeouts{ + Event: types.NewReplicaEvent(uint32(vt.opts.ID()), time.Now()), Views: vt.numViews, Timeouts: vt.numTimeouts, }) diff --git a/modules/core.go b/modules/core.go index 1c9fca2a..a3a7740b 100644 --- a/modules/core.go +++ b/modules/core.go @@ -23,133 +23,145 @@ // In general you should create an interface for your module if it is possible that someone might want to write their // own version of it in the future. // -// Finally, to set up the module system and its modules, you must create a CoreBuilder using the NewCoreBuilder function, +// Finally, to set up the module system and its modules, you must create a CoreBuilder using the NewBuilder function, // and then register all of the modules with the builder using the Register method. For example: // -// builder := NewCoreBuilder() -// // replace the logger -// builder.Register(logging.New("foo")) -// mods := builder.Build() +// builder := NewBuilder() +// // replace the logger +// builder.Add(logging.New("foo")) +// mods := builder.Build() // // If two modules satisfy the same interface, then the one that was registered last will be returned by the module system, // though note that both modules will be initialized if they implement the CoreModule interface. package modules import ( + "fmt" "reflect" "github.com/relab/hotstuff" - "github.com/relab/hotstuff/eventloop" - "github.com/relab/hotstuff/logging" ) -// CoreModule is an interface for modules that need access to a client. -type CoreModule interface { - // InitModule gives the module access to the other modules. +// Module is an interface for initializing modules. +type Module interface { InitModule(mods *Core) } // Core is the base of the module system. // It contains only a few core modules that are shared between replicas and clients. type Core struct { - id hotstuff.ID - - logger logging.Logger - metricsLogger MetricsLogger - eventLoop *eventloop.EventLoop - - modulesByType map[reflect.Type]any -} - -// ID returns the id of this client. -func (mods Core) ID() hotstuff.ID { - return mods.id + modules []any } -// Logger returns the logger. -func (mods Core) Logger() logging.Logger { - return mods.logger -} +// TryGet attempts to find a module for ptr. +// TryGet returns true if a module was stored in ptr, false otherwise. +// +// NOTE: ptr must be a non-nil pointer to a type that has been provided to the module system. +// +// Example: +// +// builder := modules.New() +// builder.Provide(MyModuleImpl{}, new(MyModule)) +// mods = builder.Build() +// +// var module MyModule +// if mods.TryGet(&module) { +// // success +// } +func (mods Core) TryGet(ptr any) bool { + v := reflect.ValueOf(ptr) + if !v.IsValid() { + panic("nil value given") + } + pt := v.Type() + if pt.Kind() != reflect.Ptr { + panic("only pointer values allowed") + } -// MetricsLogger returns the metrics logger. -func (mods Core) MetricsLogger() MetricsLogger { - if mods.metricsLogger == nil { - return NopLogger() + for _, m := range mods.modules { + mv := reflect.ValueOf(m) + if mv.Type().AssignableTo(pt.Elem()) { + v.Elem().Set(mv) + return true + } } - return mods.metricsLogger -} -// EventLoop returns the event loop. -func (mods Core) EventLoop() *eventloop.EventLoop { - return mods.eventLoop + return false } -// MetricsEventLoop returns the metrics event loop. -// The metrics event loop is used for processing of measurement data. +// Get finds a module for ptr. // -// Deprecated: The metrics event loop is no longer separate from the main event loop. Use EventLoop() instead. -func (mods Core) MetricsEventLoop() *eventloop.EventLoop { - return mods.EventLoop() +// NOTE: ptr must be a non-nil pointer to a type that has been provided to the module system. +// Get panics if ptr is not a pointer, or if a compatible module is not found. +// +// Example: +// +// builder := modules.New() +// builder.Provide(MyModuleImpl{}, new(MyModule)) +// mods = builder.Build() +// +// var module MyModule +// mods.Get(&module) +func (mods *Core) Get(ptr any) { + if !mods.TryGet(ptr) { + panic(fmt.Sprintf("module of type %s not found", reflect.TypeOf(ptr).Elem())) + } } -// GetModuleByType makes it possible to get a module based on its real type. -// This is useful for getting modules that do not implement any known module interface. -// The method returns true if a module was found, false otherwise. -// -// NOTE: dest MUST be a pointer to a variable of the desired type. -// For example: -// var module MyModule -// if mods.GetModuleByType(&module) { ... } -func (mods Core) GetModuleByType(dest any) bool { - outType := reflect.TypeOf(dest) - if outType.Kind() != reflect.Ptr { - panic("invalid argument: out must be a non-nil pointer to an interface variable") - } - targetType := outType.Elem() - if m, ok := mods.modulesByType[targetType]; ok { - reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(m)) - return true +// GetAll finds a module for all the given pointers. +// +// NOTE: pointers must only contain non-nil pointers to types that have been provided to the module system. +// GetAll panics if one of the given pointers is not a pointer, or if a compatible module is not found. +func (mods *Core) GetAll(pointers ...any) { + for _, ptr := range pointers { + mods.Get(ptr) } - return false } -// CoreBuilder is a helper for setting up client modules. -type CoreBuilder struct { - mods Core - modules []CoreModule +// Builder is a helper for setting up client modules. +type Builder struct { + core Core + modules []Module + opts *Options } -// NewCoreBuilder returns a new builder. -func NewCoreBuilder(id hotstuff.ID) CoreBuilder { - bl := CoreBuilder{mods: Core{ - id: id, - logger: logging.New(""), - eventLoop: eventloop.New(1000), - modulesByType: make(map[reflect.Type]any), - }} +// NewBuilder returns a new builder. +func NewBuilder(id hotstuff.ID, pk hotstuff.PrivateKey) Builder { + bl := Builder{ + opts: &Options{ + id: id, + privateKey: pk, + connectionMetadata: make(map[string]string), + }, + } return bl } -// Register registers the modules with the builder. -func (b *CoreBuilder) Register(modules ...any) { +// Options returns the options module. +func (b *Builder) Options() *Options { + return b.opts +} + +// Add adds modules to the builder. +func (b *Builder) Add(modules ...any) { + b.core.modules = append(b.core.modules, modules...) for _, module := range modules { - if m, ok := module.(logging.Logger); ok { - b.mods.logger = m - } - if m, ok := module.(MetricsLogger); ok { - b.mods.metricsLogger = m - } - if m, ok := module.(CoreModule); ok { + if m, ok := module.(Module); ok { b.modules = append(b.modules, m) } - b.mods.modulesByType[reflect.TypeOf(module)] = module } } -// Build initializes all registered modules and returns the Core object. -func (b *CoreBuilder) Build() *Core { +// Build initializes all added modules and returns the Core object. +func (b *Builder) Build() *Core { + // reverse the order of the added modules so that TryGet will find the latest first. + for i, j := 0, len(b.core.modules)-1; i < j; i, j = i+1, j-1 { + b.core.modules[i], b.core.modules[j] = b.core.modules[j], b.core.modules[i] + } + // add the Options last so that it can be overridden by user. + b.Add(b.opts) for _, module := range b.modules { - module.InitModule(&b.mods) + module.InitModule(&b.core) } - return &b.mods + return &b.core } diff --git a/modules/module_test.go b/modules/module_test.go new file mode 100644 index 00000000..2f5d583d --- /dev/null +++ b/modules/module_test.go @@ -0,0 +1,69 @@ +package modules_test + +import ( + "testing" + + "github.com/relab/hotstuff/modules" +) + +type Counter interface { + Increment(name string) + Count(name string) int +} + +type counterImpl struct { + counters map[string]int +} + +func (c counterImpl) Increment(name string) { c.counters[name]++ } +func (c counterImpl) Count(name string) int { return c.counters[name] } + +func NewCounter() *counterImpl { + return &counterImpl{ + counters: make(map[string]int), + } +} + +type Greeter interface { + Greet(name string) string +} + +type greeterImpl struct { + // declares dependencies on other modules + counter Counter +} + +func (g greeterImpl) Greet(name string) string { + g.counter.Increment(name) + return "Hello, " + name +} + +func NewGreeter() *greeterImpl { + return &greeterImpl{} +} + +func (g *greeterImpl) InitModule(mods *modules.Core) { + mods.Get(&g.counter) +} + +func TestModule(t *testing.T) { + builder := modules.NewBuilder(0, nil) + builder.Add(NewCounter(), NewGreeter()) + + mods := builder.Build() + + var ( + counter Counter + greeter Greeter + ) + + mods.GetAll(&counter, &greeter) + + if greeter.Greet("John") != "Hello, John" { + t.Fail() + } + + if counter.Count("John") != 1 { + t.Fail() + } +} diff --git a/modules/modules.go b/modules/modules.go index 505f0a1d..2e48988e 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -6,195 +6,8 @@ import ( "github.com/relab/hotstuff" ) -// ConsensusCore contains the modules that together implement consensus. -type ConsensusCore struct { - *Core - - privateKey hotstuff.PrivateKey - opts Options - - acceptor Acceptor - blockChain BlockChain - commandQueue CommandQueue - config Configuration - consensus Consensus - executor ExecutorExt - leaderRotation LeaderRotation - crypto Crypto - synchronizer Synchronizer - forkHandler ForkHandlerExt - handel Handel -} - -// Run starts both event loops using the provided context and returns when both event loops have exited. -func (mods *ConsensusCore) Run(ctx context.Context) { - mods.EventLoop().Run(ctx) -} - -// PrivateKey returns the private key. -func (mods *ConsensusCore) PrivateKey() hotstuff.PrivateKey { - return mods.privateKey -} - -// Options returns the current configuration settings. -func (mods *ConsensusCore) Options() *Options { - return &mods.opts -} - -// Acceptor returns the acceptor. -func (mods *ConsensusCore) Acceptor() Acceptor { - return mods.acceptor -} - -// BlockChain returns the block chain. -func (mods *ConsensusCore) BlockChain() BlockChain { - return mods.blockChain -} - -// CommandQueue returns the command queue. -func (mods *ConsensusCore) CommandQueue() CommandQueue { - return mods.commandQueue -} - -// Configuration returns the configuration of replicas. -func (mods *ConsensusCore) Configuration() Configuration { - return mods.config -} - -// Consensus returns the consensus implementation. -func (mods *ConsensusCore) Consensus() Consensus { - return mods.consensus -} - -// Executor returns the executor. -func (mods *ConsensusCore) Executor() ExecutorExt { - return mods.executor -} - -// LeaderRotation returns the leader rotation implementation. -func (mods *ConsensusCore) LeaderRotation() LeaderRotation { - return mods.leaderRotation -} - -// Crypto returns the cryptography implementation. -func (mods *ConsensusCore) Crypto() Crypto { - return mods.crypto -} - -// Synchronizer returns the view synchronizer implementation. -func (mods *ConsensusCore) Synchronizer() Synchronizer { - return mods.synchronizer -} - -// ForkHandler returns the module responsible for handling forked blocks. -func (mods *ConsensusCore) ForkHandler() ForkHandlerExt { - return mods.forkHandler -} - -// Handel returns the Handel implementation. -func (mods *ConsensusCore) Handel() Handel { - return mods.handel -} - -// ConsensusBuilder is a helper for constructing a ConsensusCore instance. -type ConsensusBuilder struct { - baseBuilder CoreBuilder - mods *ConsensusCore - cfg OptionsBuilder - modules []ConsensusModule -} - -// NewConsensusBuilder creates a new ConsensusBuilder. -func NewConsensusBuilder(id hotstuff.ID, privateKey hotstuff.PrivateKey) ConsensusBuilder { - bl := ConsensusBuilder{ - baseBuilder: NewCoreBuilder(id), - mods: &ConsensusCore{ - privateKey: privateKey, - }, - } - // using a pointer here will allow settings to be readable within InitModule - bl.cfg.opts = &bl.mods.opts - bl.cfg.opts.connectionMetadata = make(map[string]string) - return bl -} - -// Register adds modules to the HotStuff object and initializes them. -// ConsensusCore are assigned to fields based on the interface they implement. -// If only the Module interface is implemented, the InitModule function will be called, but -// the HotStuff object will not save a reference to the module. -// Register will overwrite existing modules if the same type is registered twice. -func (b *ConsensusBuilder) Register(mods ...any) { //nolint:gocyclo - for _, module := range mods { - b.baseBuilder.Register(module) - if m, ok := module.(Acceptor); ok { - b.mods.acceptor = m - } - if m, ok := module.(BlockChain); ok { - b.mods.blockChain = m - } - if m, ok := module.(CommandQueue); ok { - b.mods.commandQueue = m - } - if m, ok := module.(Configuration); ok { - b.mods.config = m - } - if m, ok := module.(Consensus); ok { - b.mods.consensus = m - } - if m, ok := module.(ExecutorExt); ok { - b.mods.executor = m - } - if m, ok := module.(Executor); ok { - b.mods.executor = executorWrapper{m} - } - if m, ok := module.(LeaderRotation); ok { - b.mods.leaderRotation = m - } - if m, ok := module.(Crypto); ok { - b.mods.crypto = m - } - if m, ok := module.(Synchronizer); ok { - b.mods.synchronizer = m - } - if m, ok := module.(ForkHandlerExt); ok { - b.mods.forkHandler = m - } - if m, ok := module.(ForkHandler); ok { - b.mods.forkHandler = forkHandlerWrapper{m} - } - if m, ok := module.(Handel); ok { - b.mods.handel = m - } - if m, ok := module.(ConsensusModule); ok { - b.modules = append(b.modules, m) - } - } -} - -// OptionsBuilder returns a pointer to the options builder. -// This can be used to configure runtime options. -func (b *ConsensusBuilder) OptionsBuilder() *OptionsBuilder { - return &b.cfg -} - -// Build initializes all modules and returns the HotStuff object. -func (b *ConsensusBuilder) Build() *ConsensusCore { - b.mods.Core = b.baseBuilder.Build() - for _, module := range b.modules { - module.InitModule(b.mods, &b.cfg) - } - return b.mods -} - // Module interfaces -// ConsensusModule is an interface that can be implemented by types that need access to other consensus modules. -type ConsensusModule interface { - // InitModule gives the module a reference to the ConsensusCore object. - // It also allows the module to set module options using the OptionsBuilder. - InitModule(mods *ConsensusCore, _ *OptionsBuilder) -} - //go:generate mockgen -destination=../internal/mocks/cmdqueue_mock.go -package=mocks . CommandQueue // CommandQueue is a queue of commands to be proposed. @@ -233,6 +46,8 @@ type ExecutorExt interface { Exec(block *hotstuff.Block) } +//go:generate mockgen -destination=../internal/mocks/forkhandler_mock.go -package=mocks . ForkHandler + // ForkHandler handles commands that do not get committed due to a forked blockchain. // // TODO: think of a better name/interface @@ -341,8 +156,6 @@ type Synchronizer interface { AdvanceView(hotstuff.SyncInfo) // View returns the current view. View() hotstuff.View - // ViewContext returns a context that is cancelled at the end of the view. - ViewContext() context.Context // HighQC returns the highest known QC. HighQC() hotstuff.QuorumCert // LeafBlock returns the current leaf block. @@ -357,18 +170,40 @@ type Handel interface { Begin(s hotstuff.PartialCert) } +// ExtendedExecutor turns the given Executor into an ExecutorExt. +func ExtendedExecutor(executor Executor) ExecutorExt { + return executorWrapper{executor} +} + type executorWrapper struct { executor Executor } +func (ew executorWrapper) InitModule(mods *Core) { + if m, ok := ew.executor.(Module); ok { + m.InitModule(mods) + } +} + func (ew executorWrapper) Exec(block *hotstuff.Block) { ew.executor.Exec(block.Command()) } +// ExtendedForkHandler turns the given ForkHandler into a ForkHandlerExt. +func ExtendedForkHandler(forkHandler ForkHandler) ForkHandlerExt { + return forkHandlerWrapper{forkHandler} +} + type forkHandlerWrapper struct { forkHandler ForkHandler } +func (fhw forkHandlerWrapper) InitModule(mods *Core) { + if m, ok := fhw.forkHandler.(Module); ok { + m.InitModule(mods) + } +} + func (fhw forkHandlerWrapper) Fork(block *hotstuff.Block) { fhw.forkHandler.Fork(block.Command()) } diff --git a/modules/options.go b/modules/options.go index f8b037f2..94a87db3 100644 --- a/modules/options.go +++ b/modules/options.go @@ -1,7 +1,12 @@ package modules +import "github.com/relab/hotstuff" + // Options stores runtime configuration settings. type Options struct { + id hotstuff.ID + privateKey hotstuff.PrivateKey + shouldUseAggQC bool shouldUseHandel bool shouldVerifyVotesSync bool @@ -10,56 +15,61 @@ type Options struct { connectionMetadata map[string]string } +// ID returns the ID. +func (opts Options) ID() hotstuff.ID { + return opts.id +} + +// PrivateKey returns the private key. +func (opts Options) PrivateKey() hotstuff.PrivateKey { + return opts.privateKey +} + // ShouldUseAggQC returns true if aggregated quorum certificates should be used. // This is true for Fast-Hotstuff: https://arxiv.org/abs/2010.11454 -func (c Options) ShouldUseAggQC() bool { - return c.shouldUseAggQC +func (opts Options) ShouldUseAggQC() bool { + return opts.shouldUseAggQC } // ShouldUseHandel returns true if the Handel signature aggregation protocol should be used. -func (c Options) ShouldUseHandel() bool { - return c.shouldUseHandel +func (opts Options) ShouldUseHandel() bool { + return opts.shouldUseHandel } // ShouldVerifyVotesSync returns true if votes should be verified synchronously. // Enabling this should make the voting machine process votes synchronously. -func (c Options) ShouldVerifyVotesSync() bool { - return c.shouldVerifyVotesSync +func (opts Options) ShouldVerifyVotesSync() bool { + return opts.shouldVerifyVotesSync } // SharedRandomSeed returns a random number that is shared between all replicas. -func (c Options) SharedRandomSeed() int64 { - return c.sharedRandomSeed +func (opts Options) SharedRandomSeed() int64 { + return opts.sharedRandomSeed } // ConnectionMetadata returns the metadata map that is sent when connecting to other replicas. -func (c Options) ConnectionMetadata() map[string]string { - return c.connectionMetadata -} - -// OptionsBuilder is used to set the values of immutable configuration settings. -type OptionsBuilder struct { - opts *Options +func (opts Options) ConnectionMetadata() map[string]string { + return opts.connectionMetadata } // SetShouldUseAggQC sets the ShouldUseAggQC setting to true. -func (builder *OptionsBuilder) SetShouldUseAggQC() { - builder.opts.shouldUseAggQC = true +func (opts *Options) SetShouldUseAggQC() { + opts.shouldUseAggQC = true } // SetShouldUseHandel sets the ShouldUseHandel setting to true. -func (builder *OptionsBuilder) SetShouldUseHandel() { - builder.opts.shouldUseHandel = true +func (opts *Options) SetShouldUseHandel() { + opts.shouldUseHandel = true } // SetShouldVerifyVotesSync sets the ShouldVerifyVotesSync setting to true. -func (builder *OptionsBuilder) SetShouldVerifyVotesSync() { - builder.opts.shouldVerifyVotesSync = true +func (opts *Options) SetShouldVerifyVotesSync() { + opts.shouldVerifyVotesSync = true } // SetSharedRandomSeed sets the shared random seed. -func (builder *OptionsBuilder) SetSharedRandomSeed(seed int64) { - builder.opts.sharedRandomSeed = seed +func (opts *Options) SetSharedRandomSeed(seed int64) { + opts.sharedRandomSeed = seed } // SetConnectionMetadata sets the value of a key in the connection metadata map. @@ -67,6 +77,6 @@ func (builder *OptionsBuilder) SetSharedRandomSeed(seed int64) { // NOTE: if the value contains binary data, the key must have the "-bin" suffix. // This is to make it compatible with GRPC metadata. // See: https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#storing-binary-data-in-metadata -func (builder *OptionsBuilder) SetConnectionMetadata(key string, value string) { - builder.opts.connectionMetadata[key] = value +func (opts *Options) SetConnectionMetadata(key string, value string) { + opts.connectionMetadata[key] = value } diff --git a/replica/clientsrv.go b/replica/clientsrv.go index 2b3519eb..5e2f0e62 100644 --- a/replica/clientsrv.go +++ b/replica/clientsrv.go @@ -2,13 +2,16 @@ package replica import ( "crypto/sha256" - "github.com/relab/hotstuff" "hash" "net" "sync" + "github.com/relab/hotstuff" + "github.com/relab/gorums" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/internal/proto/clientpb" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -18,8 +21,10 @@ import ( // clientSrv serves a client. type clientSrv struct { + eventLoop *eventloop.EventLoop + logger logging.Logger + mut sync.Mutex - mods *modules.Core srv *gorums.Server awaitingCmds map[cmdID]chan<- error cmdCache *cmdCache @@ -40,7 +45,10 @@ func newClientServer(conf Config, srvOpts []gorums.ServerOption) (srv *clientSrv // InitModule gives the module access to the other modules. func (srv *clientSrv) InitModule(mods *modules.Core) { - srv.mods = mods + mods.GetAll( + &srv.eventLoop, + &srv.logger, + ) srv.cmdCache.InitModule(mods) } @@ -57,7 +65,7 @@ func (srv *clientSrv) StartOnListener(lis net.Listener) { go func() { err := srv.srv.Serve(lis) if err != nil { - srv.mods.Logger().Error(err) + srv.logger.Error(err) } }() } @@ -84,11 +92,11 @@ func (srv *clientSrv) Exec(cmd hotstuff.Command) { batch := new(clientpb.Batch) err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal([]byte(cmd), batch) if err != nil { - srv.mods.Logger().Errorf("Failed to unmarshal command: %v", err) + srv.logger.Errorf("Failed to unmarshal command: %v", err) return } - srv.mods.EventLoop().AddEvent(hotstuff.CommitEvent{Commands: len(batch.GetCommands())}) + srv.eventLoop.AddEvent(hotstuff.CommitEvent{Commands: len(batch.GetCommands())}) for _, cmd := range batch.GetCommands() { _, _ = srv.hash.Write(cmd.Data) @@ -101,14 +109,14 @@ func (srv *clientSrv) Exec(cmd hotstuff.Command) { srv.mut.Unlock() } - srv.mods.Logger().Debugf("Hash: %.8x", srv.hash.Sum(nil)) + srv.logger.Debugf("Hash: %.8x", srv.hash.Sum(nil)) } func (srv *clientSrv) Fork(cmd hotstuff.Command) { batch := new(clientpb.Batch) err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal([]byte(cmd), batch) if err != nil { - srv.mods.Logger().Errorf("Failed to unmarshal command: %v", err) + srv.logger.Errorf("Failed to unmarshal command: %v", err) return } diff --git a/replica/cmdcache.go b/replica/cmdcache.go index 961a6488..f13ad6bf 100644 --- a/replica/cmdcache.go +++ b/replica/cmdcache.go @@ -3,17 +3,20 @@ package replica import ( "container/list" "context" - "github.com/relab/hotstuff" "sync" + "github.com/relab/hotstuff" + "github.com/relab/hotstuff/internal/proto/clientpb" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "google.golang.org/protobuf/proto" ) type cmdCache struct { + logger logging.Logger + mut sync.Mutex - mods *modules.Core c chan struct{} batchSize int serialNumbers map[uint32]uint64 // highest proposed serial number per client ID @@ -34,7 +37,7 @@ func newCmdCache(batchSize int) *cmdCache { // InitModule gives the module access to the other modules. func (c *cmdCache) InitModule(mods *modules.Core) { - c.mods = mods + mods.Get(&c.logger) } func (c *cmdCache) addCommand(cmd *clientpb.Command) { @@ -98,7 +101,7 @@ awaitBatch: // otherwise, we should have at least one command b, err := c.marshaler.Marshal(batch) if err != nil { - c.mods.Logger().Errorf("Failed to marshal batch: %v", err) + c.logger.Errorf("Failed to marshal batch: %v", err) return "", false } @@ -111,7 +114,7 @@ func (c *cmdCache) Accept(cmd hotstuff.Command) bool { batch := new(clientpb.Batch) err := c.unmarshaler.Unmarshal([]byte(cmd), batch) if err != nil { - c.mods.Logger().Errorf("Failed to unmarshal batch: %v", err) + c.logger.Errorf("Failed to unmarshal batch: %v", err) return false } @@ -133,7 +136,7 @@ func (c *cmdCache) Proposed(cmd hotstuff.Command) { batch := new(clientpb.Batch) err := c.unmarshaler.Unmarshal([]byte(cmd), batch) if err != nil { - c.mods.Logger().Errorf("Failed to unmarshal batch: %v", err) + c.logger.Errorf("Failed to unmarshal batch: %v", err) return } diff --git a/replica/replica.go b/replica/replica.go index 09f6fd9e..b8596bf7 100644 --- a/replica/replica.go +++ b/replica/replica.go @@ -5,9 +5,11 @@ import ( "context" "crypto/tls" "crypto/x509" - "github.com/relab/hotstuff/modules" "net" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/modules" + "github.com/relab/gorums" "github.com/relab/hotstuff" "github.com/relab/hotstuff/backend" @@ -49,7 +51,7 @@ type Replica struct { clientSrv *clientSrv cfg *backend.Config hsSrv *backend.Server - hs *modules.ConsensusCore + hs *modules.Core execHandlers map[cmdID]func(*emptypb.Empty, error) cancel context.CancelFunc @@ -57,7 +59,7 @@ type Replica struct { } // New returns a new replica. -func New(conf Config, builder modules.ConsensusBuilder) (replica *Replica) { +func New(conf Config, builder modules.Builder) (replica *Replica) { clientSrvOpts := conf.ClientServerOptions if conf.TLS { @@ -98,11 +100,14 @@ func New(conf Config, builder modules.ConsensusBuilder) (replica *Replica) { } srv.cfg = backend.NewConfig(creds, managerOpts...) - builder.Register( - srv.cfg, // configuration - srv.hsSrv, // event handling - srv.clientSrv, // executor - srv.clientSrv.cmdCache, // acceptor and command queue + builder.Add( + srv.cfg, // configuration + srv.hsSrv, // event handling + + modules.ExtendedExecutor(srv.clientSrv), + modules.ExtendedForkHandler(srv.clientSrv), + srv.clientSrv.cmdCache, + srv.clientSrv.cmdCache, ) srv.hs = builder.Build() @@ -110,7 +115,7 @@ func New(conf Config, builder modules.ConsensusBuilder) (replica *Replica) { } // Modules returns the Modules object of this replica. -func (srv *Replica) Modules() *modules.ConsensusCore { +func (srv *Replica) Modules() *modules.Core { return srv.hs } @@ -144,8 +149,14 @@ func (srv *Replica) Stop() { // Run runs the replica until the context is cancelled. func (srv *Replica) Run(ctx context.Context) { - srv.hs.Synchronizer().Start(ctx) - srv.hs.Run(ctx) + var ( + synchronizer modules.Synchronizer + eventLoop *eventloop.EventLoop + ) + srv.hs.GetAll(&synchronizer, &eventLoop) + + synchronizer.Start(ctx) + eventLoop.Run(ctx) } // Close closes the connections and stops the servers used by the replica. diff --git a/synchronizer/context.go b/synchronizer/context.go new file mode 100644 index 00000000..299f63e5 --- /dev/null +++ b/synchronizer/context.go @@ -0,0 +1,44 @@ +package synchronizer + +import ( + "context" + + "github.com/relab/hotstuff" + "github.com/relab/hotstuff/eventloop" +) + +// This file provides several functions for creating contexts with lifespans that are tied to synchronizer events. + +// ViewContext returns a context that is cancelled at the end of a view. +// If view is nil or less than or equal to the current view, the context will be cancelled at the next view change. +// +// ViewContext should probably not be used for operations running on the event loop, because +func ViewContext(parent context.Context, eventLoop *eventloop.EventLoop, view *hotstuff.View) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + id := eventLoop.RegisterHandler(ViewChangeEvent{}, func(event any) { + if view == nil || event.(ViewChangeEvent).View >= *view { + cancel() + } + }, eventloop.RunAsync(), eventloop.WithPriority()) + + return ctx, func() { + eventLoop.UnregisterHandler(ViewChangeEvent{}, id) + cancel() + } +} + +// TimeoutContext returns a context that is cancelled either when a timeout occurs, or when the view changes. +func TimeoutContext(parent context.Context, eventLoop *eventloop.EventLoop) (context.Context, context.CancelFunc) { + // ViewContext handles view-change case. + ctx, cancel := ViewContext(parent, eventLoop, nil) + + id := eventLoop.RegisterHandler(TimeoutEvent{}, func(event any) { + cancel() + }, eventloop.RunAsync(), eventloop.WithPriority()) + + return ctx, func() { + eventLoop.UnregisterHandler(TimeoutEvent{}, id) + cancel() + } +} diff --git a/synchronizer/synchronizer.go b/synchronizer/synchronizer.go index 2c4a7ce3..95c920cf 100644 --- a/synchronizer/synchronizer.go +++ b/synchronizer/synchronizer.go @@ -3,8 +3,11 @@ package synchronizer import ( "context" "fmt" + "sync" "time" + "github.com/relab/hotstuff/eventloop" + "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "github.com/relab/hotstuff" @@ -12,8 +15,16 @@ import ( // Synchronizer synchronizes replicas to the same view. type Synchronizer struct { - mods *modules.ConsensusCore - + blockChain modules.BlockChain + consensus modules.Consensus + crypto modules.Crypto + configuration modules.Configuration + eventLoop *eventloop.EventLoop + leaderRotation modules.LeaderRotation + logger logging.Logger + opts *modules.Options + + mut sync.RWMutex // to protect the following currentView hotstuff.View highTC hotstuff.TimeoutCert highQC hotstuff.QuorumCert @@ -27,60 +38,57 @@ type Synchronizer struct { duration ViewDuration timer *time.Timer - viewCtx context.Context // a context that is cancelled at the end of the current view - cancelCtx context.CancelFunc - // map of collected timeout messages per view timeouts map[hotstuff.View]map[hotstuff.ID]hotstuff.TimeoutMsg } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (s *Synchronizer) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - if duration, ok := s.duration.(modules.ConsensusModule); ok { - duration.InitModule(mods, opts) - } - s.mods = mods +// InitModule initializes the synchronizer. +func (s *Synchronizer) InitModule(mods *modules.Core) { + mods.GetAll( + &s.blockChain, + &s.consensus, + &s.crypto, + &s.configuration, + &s.eventLoop, + &s.leaderRotation, + &s.logger, + &s.opts, + ) - s.mods.EventLoop().RegisterHandler(TimeoutEvent{}, func(event any) { + s.eventLoop.RegisterHandler(TimeoutEvent{}, func(event any) { timeoutView := event.(TimeoutEvent).View - if s.currentView == timeoutView { + if s.View() == timeoutView { s.OnLocalTimeout() } }) - s.mods.EventLoop().RegisterHandler(hotstuff.NewViewMsg{}, func(event any) { + s.eventLoop.RegisterHandler(hotstuff.NewViewMsg{}, func(event any) { newViewMsg := event.(hotstuff.NewViewMsg) s.OnNewView(newViewMsg) }) - s.mods.EventLoop().RegisterHandler(hotstuff.TimeoutMsg{}, func(event any) { + s.eventLoop.RegisterHandler(hotstuff.TimeoutMsg{}, func(event any) { timeoutMsg := event.(hotstuff.TimeoutMsg) s.OnRemoteTimeout(timeoutMsg) }) var err error - s.highQC, err = s.mods.Crypto().CreateQuorumCert(hotstuff.GetGenesis(), []hotstuff.PartialCert{}) + s.highQC, err = s.crypto.CreateQuorumCert(hotstuff.GetGenesis(), []hotstuff.PartialCert{}) if err != nil { panic(fmt.Errorf("unable to create empty quorum cert for genesis block: %v", err)) } - s.highTC, err = s.mods.Crypto().CreateTimeoutCert(hotstuff.View(0), []hotstuff.TimeoutMsg{}) + s.highTC, err = s.crypto.CreateTimeoutCert(hotstuff.View(0), []hotstuff.TimeoutMsg{}) if err != nil { panic(fmt.Errorf("unable to create empty timeout cert for view 0: %v", err)) } - } // New creates a new Synchronizer. func New(viewDuration ViewDuration) modules.Synchronizer { - ctx, cancel := context.WithCancel(context.Background()) return &Synchronizer{ leafBlock: hotstuff.GetGenesis(), currentView: 1, - viewCtx: ctx, - cancelCtx: cancel, - duration: viewDuration, timer: time.AfterFunc(0, func() {}), // dummy timer that will be replaced after start() is called @@ -92,8 +100,7 @@ func New(viewDuration ViewDuration) modules.Synchronizer { func (s *Synchronizer) Start(ctx context.Context) { s.timer = time.AfterFunc(s.duration.Duration(), func() { // The event loop will execute onLocalTimeout for us. - s.cancelCtx() - s.mods.EventLoop().AddEvent(TimeoutEvent{s.currentView}) + s.eventLoop.AddEvent(TimeoutEvent{s.View()}) }) go func() { @@ -102,8 +109,8 @@ func (s *Synchronizer) Start(ctx context.Context) { }() // start the initial proposal - if s.currentView == 1 && s.mods.LeaderRotation().GetLeader(s.currentView) == s.mods.ID() { - s.mods.Consensus().Propose(s.SyncInfo()) + if view := s.View(); view == 1 && s.leaderRotation.GetLeader(view) == s.opts.ID() { + s.consensus.Propose(s.SyncInfo()) } } @@ -119,85 +126,79 @@ func (s *Synchronizer) LeafBlock() *hotstuff.Block { // View returns the current view. func (s *Synchronizer) View() hotstuff.View { + s.mut.RLock() + defer s.mut.RUnlock() return s.currentView } -// ViewContext returns a context that is cancelled at the end of the view. -func (s *Synchronizer) ViewContext() context.Context { - return s.viewCtx -} - // SyncInfo returns the highest known QC or TC. func (s *Synchronizer) SyncInfo() hotstuff.SyncInfo { + s.mut.RLock() + defer s.mut.RUnlock() return hotstuff.NewSyncInfo().WithQC(s.highQC).WithTC(s.highTC) } // OnLocalTimeout is called when a local timeout happens. func (s *Synchronizer) OnLocalTimeout() { - // Reset the timer and ctx here so that we can get a new timeout in the same view. - // I think this is necessary to ensure that we can keep sending the same timeout message - // until we get a timeout certificate. - // - // TODO: figure out the best way to handle this context and timeout. - if s.viewCtx.Err() != nil { - s.newCtx(s.duration.Duration()) - } s.timer.Reset(s.duration.Duration()) - if s.lastTimeout != nil && s.lastTimeout.View == s.currentView { - s.mods.Configuration().Timeout(*s.lastTimeout) + view := s.View() + + if s.lastTimeout != nil && s.lastTimeout.View == view { + s.configuration.Timeout(*s.lastTimeout) return } s.duration.ViewTimeout() // increase the duration of the next view - view := s.currentView - s.mods.Logger().Debugf("OnLocalTimeout: %v", view) + s.logger.Debugf("OnLocalTimeout: %v", view) - sig, err := s.mods.Crypto().Sign(view.ToBytes()) + sig, err := s.crypto.Sign(view.ToBytes()) if err != nil { - s.mods.Logger().Warnf("Failed to sign view: %v", err) + s.logger.Warnf("Failed to sign view: %v", err) return } timeoutMsg := hotstuff.TimeoutMsg{ - ID: s.mods.ID(), + ID: s.opts.ID(), View: view, SyncInfo: s.SyncInfo(), ViewSignature: sig, } - if s.mods.Options().ShouldUseAggQC() { + if s.opts.ShouldUseAggQC() { // generate a second signature that will become part of the aggregateQC - sig, err := s.mods.Crypto().Sign(timeoutMsg.ToBytes()) + sig, err := s.crypto.Sign(timeoutMsg.ToBytes()) if err != nil { - s.mods.Logger().Warnf("Failed to sign timeout message: %v", err) + s.logger.Warnf("Failed to sign timeout message: %v", err) return } timeoutMsg.MsgSignature = sig } s.lastTimeout = &timeoutMsg // stop voting for current view - s.mods.Consensus().StopVoting(s.currentView) + s.consensus.StopVoting(view) - s.mods.Configuration().Timeout(timeoutMsg) + s.configuration.Timeout(timeoutMsg) s.OnRemoteTimeout(timeoutMsg) } // OnRemoteTimeout handles an incoming timeout from a remote replica. func (s *Synchronizer) OnRemoteTimeout(timeout hotstuff.TimeoutMsg) { + currView := s.View() + defer func() { // cleanup old timeouts for view := range s.timeouts { - if view < s.currentView { + if view < currView { delete(s.timeouts, view) } } }() - verifier := s.mods.Crypto() + verifier := s.crypto if !verifier.Verify(timeout.ViewSignature, timeout.View.ToBytes()) { return } - s.mods.Logger().Debug("OnRemoteTimeout: ", timeout) + s.logger.Debug("OnRemoteTimeout: ", timeout) s.AdvanceView(timeout.SyncInfo) @@ -211,7 +212,7 @@ func (s *Synchronizer) OnRemoteTimeout(timeout hotstuff.TimeoutMsg) { timeouts[timeout.ID] = timeout } - if len(timeouts) < s.mods.Configuration().QuorumSize() { + if len(timeouts) < s.configuration.QuorumSize() { return } @@ -222,18 +223,18 @@ func (s *Synchronizer) OnRemoteTimeout(timeout hotstuff.TimeoutMsg) { timeoutList = append(timeoutList, t) } - tc, err := s.mods.Crypto().CreateTimeoutCert(timeout.View, timeoutList) + tc, err := s.crypto.CreateTimeoutCert(timeout.View, timeoutList) if err != nil { - s.mods.Logger().Debugf("Failed to create timeout certificate: %v", err) + s.logger.Debugf("Failed to create timeout certificate: %v", err) return } si := s.SyncInfo().WithTC(tc) - if s.mods.Options().ShouldUseAggQC() { - aggQC, err := s.mods.Crypto().CreateAggregateQC(s.currentView, timeoutList) + if s.opts.ShouldUseAggQC() { + aggQC, err := s.crypto.CreateAggregateQC(currView, timeoutList) if err != nil { - s.mods.Logger().Debugf("Failed to create aggregateQC: %v", err) + s.logger.Debugf("Failed to create aggregateQC: %v", err) } else { si = si.WithAggQC(aggQC) } @@ -257,8 +258,8 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { // check for a TC if tc, ok := syncInfo.TC(); ok { - if !s.mods.Crypto().VerifyTimeoutCert(tc) { - s.mods.Logger().Info("Timeout Certificate could not be verified!") + if !s.crypto.VerifyTimeoutCert(tc) { + s.logger.Info("Timeout Certificate could not be verified!") return } s.updateHighTC(tc) @@ -273,10 +274,10 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { ) // check for an AggQC or QC - if aggQC, haveQC = syncInfo.AggQC(); haveQC && s.mods.Options().ShouldUseAggQC() { - highQC, ok := s.mods.Crypto().VerifyAggregateQC(aggQC) + if aggQC, haveQC = syncInfo.AggQC(); haveQC && s.opts.ShouldUseAggQC() { + highQC, ok := s.crypto.VerifyAggregateQC(aggQC) if !ok { - s.mods.Logger().Info("Aggregated Quorum Certificate could not be verified") + s.logger.Info("Aggregated Quorum Certificate could not be verified") return } if aggQC.View() >= v { @@ -287,8 +288,8 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { syncInfo = syncInfo.WithQC(highQC) qc = highQC } else if qc, haveQC = syncInfo.QC(); haveQC { - if !s.mods.Crypto().VerifyQuorumCert(qc) { - s.mods.Logger().Info("Quorum Certificate could not be verified!") + if !s.crypto.VerifyQuorumCert(qc) { + s.logger.Info("Quorum Certificate could not be verified!") return } } @@ -302,7 +303,7 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { } } - if v < s.currentView { + if v < s.View() { return } @@ -312,22 +313,25 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { s.duration.ViewSucceeded() } - s.currentView = v + 1 + newView := v + 1 + + s.mut.Lock() + s.currentView = newView + s.mut.Unlock() + s.lastTimeout = nil s.duration.ViewStarted() duration := s.duration.Duration() - // cancel the old view context and set up the next one - s.newCtx(duration) s.timer.Reset(duration) - s.mods.Logger().Debugf("advanced to view %d", s.currentView) - s.mods.EventLoop().AddEvent(ViewChangeEvent{View: s.currentView, Timeout: timeout}) + s.logger.Debugf("advanced to view %d", newView) + s.eventLoop.AddEvent(ViewChangeEvent{View: newView, Timeout: timeout}) - leader := s.mods.LeaderRotation().GetLeader(s.currentView) - if leader == s.mods.ID() { - s.mods.Consensus().Propose(syncInfo) - } else if replica, ok := s.mods.Configuration().Replica(leader); ok { + leader := s.leaderRotation.GetLeader(newView) + if leader == s.opts.ID() { + s.consensus.Propose(syncInfo) + } else if replica, ok := s.configuration.Replica(leader); ok { replica.NewView(syncInfo) } } @@ -336,21 +340,21 @@ func (s *Synchronizer) AdvanceView(syncInfo hotstuff.SyncInfo) { // This method is meant to be used instead of the exported UpdateHighQC internally // in this package when the qc has already been verified. func (s *Synchronizer) updateHighQC(qc hotstuff.QuorumCert) { - newBlock, ok := s.mods.BlockChain().Get(qc.BlockHash()) + newBlock, ok := s.blockChain.Get(qc.BlockHash()) if !ok { - s.mods.Logger().Info("updateHighQC: Could not find block referenced by new QC!") + s.logger.Info("updateHighQC: Could not find block referenced by new QC!") return } - oldBlock, ok := s.mods.BlockChain().Get(s.highQC.BlockHash()) + oldBlock, ok := s.blockChain.Get(s.highQC.BlockHash()) if !ok { - s.mods.Logger().Panic("Block from the old highQC missing from chain") + s.logger.Panic("Block from the old highQC missing from chain") } if newBlock.View() > oldBlock.View() { s.highQC = qc s.leafBlock = newBlock - s.mods.Logger().Debug("HighQC updated") + s.logger.Debug("HighQC updated") } } @@ -358,15 +362,10 @@ func (s *Synchronizer) updateHighQC(qc hotstuff.QuorumCert) { func (s *Synchronizer) updateHighTC(tc hotstuff.TimeoutCert) { if tc.View() > s.highTC.View() { s.highTC = tc - s.mods.Logger().Debug("HighTC updated") + s.logger.Debug("HighTC updated") } } -func (s *Synchronizer) newCtx(duration time.Duration) { - s.cancelCtx() - s.viewCtx, s.cancelCtx = context.WithTimeout(context.Background(), duration) -} - var _ modules.Synchronizer = (*Synchronizer)(nil) // ViewChangeEvent is sent on the eventloop whenever a view change occurs. diff --git a/synchronizer/synchronizer_test.go b/synchronizer/synchronizer_test.go index f0e10666..6f8b6780 100644 --- a/synchronizer/synchronizer_test.go +++ b/synchronizer/synchronizer_test.go @@ -1,68 +1,24 @@ package synchronizer_test import ( - "bytes" - "context" - "github.com/relab/hotstuff" - "github.com/relab/hotstuff/modules" "testing" + "github.com/relab/hotstuff" + "github.com/golang/mock/gomock" "github.com/relab/hotstuff/internal/mocks" "github.com/relab/hotstuff/internal/testutil" + "github.com/relab/hotstuff/modules" . "github.com/relab/hotstuff/synchronizer" ) -func TestLocalTimeout(t *testing.T) { - ctrl := gomock.NewController(t) - qc := hotstuff.NewQuorumCert(nil, 0, hotstuff.GetGenesis().Hash()) - key := testutil.GenerateECDSAKey(t) - builder := modules.NewConsensusBuilder(2, key) - testutil.TestModules(t, ctrl, 2, key, &builder) - hs := mocks.NewMockConsensus(ctrl) - s := New(testutil.FixedTimeout(10)) - builder.Register(hs, s) - mods := builder.Build() - cfg := mods.Configuration().(*mocks.MockConfiguration) - leader := testutil.CreateMockReplica(t, ctrl, 1, testutil.GenerateECDSAKey(t)) - testutil.ConfigAddReplica(t, cfg, leader) - - c := make(chan struct{}) - hs.EXPECT().StopVoting(hotstuff.View(1)).AnyTimes() - cfg. - EXPECT(). - Timeout(gomock.AssignableToTypeOf(hotstuff.TimeoutMsg{})). - Do(func(msg hotstuff.TimeoutMsg) { - if msg.View != 1 { - t.Errorf("wrong view. got: %v, want: %v", msg.View, 1) - } - if msg.ID != 2 { - t.Errorf("wrong ID. got: %v, want: %v", msg.ID, 2) - } - if msgQC, ok := msg.SyncInfo.QC(); ok && !bytes.Equal(msgQC.ToBytes(), qc.ToBytes()) { - t.Errorf("wrong QC. got: %v, want: %v", msgQC, qc) - } - if !mods.Crypto().Verify(msg.ViewSignature, msg.View.ToBytes()) { - t.Error("failed to verify signature") - } - c <- struct{}{} - }).AnyTimes() - ctx, cancel := context.WithCancel(context.Background()) - go func() { - mods.Synchronizer().Start(ctx) - mods.Run(ctx) - }() - <-c - cancel() -} - func TestAdvanceViewQC(t *testing.T) { const n = 4 ctrl := gomock.NewController(t) builders := testutil.CreateBuilders(t, ctrl, n) s := New(testutil.FixedTimeout(1000)) hs := mocks.NewMockConsensus(ctrl) - builders[0].Register(s, hs) + builders[0].Add(s, hs) hl := builders.Build() signers := hl.Signers() @@ -74,7 +30,11 @@ func TestAdvanceViewQC(t *testing.T) { 1, 2, ) - hl[0].BlockChain().Store(block) + + var blockChain modules.BlockChain + hl[0].Get(&blockChain) + + blockChain.Store(block) qc := testutil.CreateQC(t, block, signers) // synchronizer should tell hotstuff to propose hs.EXPECT().Propose(gomock.AssignableToTypeOf(hotstuff.NewSyncInfo())) @@ -92,7 +52,7 @@ func TestAdvanceViewTC(t *testing.T) { builders := testutil.CreateBuilders(t, ctrl, n) s := New(testutil.FixedTimeout(100)) hs := mocks.NewMockConsensus(ctrl) - builders[0].Register(s, hs) + builders[0].Add(s, hs) hl := builders.Build() signers := hl.Signers() @@ -108,28 +68,3 @@ func TestAdvanceViewTC(t *testing.T) { t.Errorf("wrong view: expected: %v, got: %v", 2, s.View()) } } - -// func TestRemoteTimeout(t *testing.T) { -// const n = 4 -// ctrl := gomock.NewController(t) -// builders := testutil.CreateBuilders(t, ctrl, n) -// s := New(testutil.FixedTimeout(100)) -// hs := mocks.NewMockConsensus(ctrl) -// builders[0].Register(s, hs) - -// hl := builders.Build() -// signers := hl.Signers() - -// timeouts := testutil.CreateTimeouts(t, 1, signers[1:]) - -// // synchronizer should tell hotstuff to propose -// hs.EXPECT().Propose(gomock.AssignableToTypeOf(consensus.NewSyncInfo())) - -// for _, timeout := range timeouts { -// s.OnRemoteTimeout(timeout) -// } - -// if s.View() != 2 { -// t.Errorf("wrong view: expected: %v, got: %v", 2, s.View()) -// } -// } diff --git a/synchronizer/viewduration.go b/synchronizer/viewduration.go index 9cce4710..2b9d15a9 100644 --- a/synchronizer/viewduration.go +++ b/synchronizer/viewduration.go @@ -1,7 +1,6 @@ package synchronizer import ( - "github.com/relab/hotstuff/modules" "math" "time" ) @@ -35,7 +34,6 @@ func NewViewDuration(sampleSize uint64, startTimeout, maxTimeout, multiplier flo // viewDuration uses statistics from previous views to guess a good value for the view duration. // It only takes a limited amount of measurements into account. type viewDuration struct { - mods *modules.ConsensusCore mul float64 // on failed views, multiply the current mean by this number (should be > 1) limit uint64 // how many measurements should be included in mean count uint64 // total number of measurements @@ -46,12 +44,6 @@ type viewDuration struct { max float64 // upper bound on view timeout } -// InitModule gives the module a reference to the ConsensusCore object. -// It also allows the module to set module options using the OptionsBuilder. -func (v *viewDuration) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - v.mods = mods -} - // ViewSucceeded calculates the duration of the view // and updates the internal values used for mean and variance calculations. func (v *viewDuration) ViewSucceeded() { @@ -116,8 +108,5 @@ func (v *viewDuration) Duration() time.Duration { duration = v.max } - if uint64(v.mods.Synchronizer().View())%v.limit == 0 { - v.mods.Logger().Infof("Mean: %.2fms, Dev: %.2f, Timeout: %.2fms (last %d views)", v.mean, dev, duration, v.limit) - } return time.Duration(duration * float64(time.Millisecond)) } diff --git a/twins/fhsbug_test.go b/twins/fhsbug_test.go index 225f77e8..beb7de1c 100644 --- a/twins/fhsbug_test.go +++ b/twins/fhsbug_test.go @@ -135,15 +135,18 @@ func TestFHSBug(t *testing.T) { // A wrapper around the FHS rules that swaps the commit rule for a vulnerable version type vulnerableFHS struct { - mods *modules.ConsensusCore - inner fasthotstuff.FastHotStuff + logger logging.Logger + blockChain modules.BlockChain + inner fasthotstuff.FastHotStuff } -// InitModule gives the module a reference to the Modules object. -// It also allows the module to set module options using the OptionsBuilder. -func (fhs *vulnerableFHS) InitModule(mods *modules.ConsensusCore, opts *modules.OptionsBuilder) { - fhs.mods = mods - fhs.inner.InitModule(mods, opts) +func (fhs *vulnerableFHS) InitModule(mods *modules.Core) { + mods.GetAll( + &fhs.logger, + &fhs.blockChain, + ) + + fhs.inner.InitModule(mods) } // VoteRule decides whether to vote for the block. @@ -155,7 +158,7 @@ func (fhs *vulnerableFHS) qcRef(qc hotstuff.QuorumCert) (*hotstuff.Block, bool) if (hotstuff.Hash{}) == qc.BlockHash() { return nil, false } - return fhs.mods.BlockChain().Get(qc.BlockHash()) + return fhs.blockChain.Get(qc.BlockHash()) } // CommitRule decides whether an ancestor of the block can be committed. @@ -164,7 +167,7 @@ func (fhs *vulnerableFHS) CommitRule(block *hotstuff.Block) *hotstuff.Block { if !ok { return nil } - fhs.mods.Logger().Debug("PRECOMMIT: ", parent) + fhs.logger.Debug("PRECOMMIT: ", parent) grandparent, ok := fhs.qcRef(parent.QuorumCert()) if !ok { return nil @@ -172,7 +175,7 @@ func (fhs *vulnerableFHS) CommitRule(block *hotstuff.Block) *hotstuff.Block { // NOTE: this does check for a direct link between the block and the grandparent. // This is what causes the safety violation. if block.Parent() == parent.Hash() && parent.Parent() == grandparent.Hash() { - fhs.mods.Logger().Debug("COMMIT(vulnerable): ", grandparent) + fhs.logger.Debug("COMMIT(vulnerable): ", grandparent) return grandparent } return nil diff --git a/twins/network.go b/twins/network.go index 5f328608..64176dd9 100644 --- a/twins/network.go +++ b/twins/network.go @@ -14,6 +14,7 @@ import ( "github.com/relab/hotstuff/crypto" "github.com/relab/hotstuff/crypto/ecdsa" "github.com/relab/hotstuff/crypto/keygen" + "github.com/relab/hotstuff/eventloop" "github.com/relab/hotstuff/logging" "github.com/relab/hotstuff/modules" "github.com/relab/hotstuff/synchronizer" @@ -34,13 +35,30 @@ func (id NodeID) String() string { } type node struct { + blockChain modules.BlockChain + consensus modules.Consensus + eventLoop *eventloop.EventLoop + leaderRotation modules.LeaderRotation + synchronizer modules.Synchronizer + opts *modules.Options + id NodeID - mods *modules.ConsensusCore executedBlocks []*hotstuff.Block effectiveView hotstuff.View log strings.Builder } +func (n *node) InitModule(mods *modules.Core) { + mods.GetAll( + &n.blockChain, + &n.consensus, + &n.eventLoop, + &n.leaderRotation, + &n.synchronizer, + &n.opts, + ) +} + type pendingMessage struct { message any receiver uint32 @@ -89,16 +107,16 @@ func NewPartitionedNetwork(views []View, dropTypes ...any) *Network { return n } -// GetNodeBuilder returns a consensus.ConsensusBuilder instance for a node in the network. -func (n *Network) GetNodeBuilder(id NodeID, pk hotstuff.PrivateKey) modules.ConsensusBuilder { +// GetNodeBuilder returns a consensus.Builder instance for a node in the network. +func (n *Network) GetNodeBuilder(id NodeID, pk hotstuff.PrivateKey) modules.Builder { node := node{ id: id, } n.nodes[id.NetworkID] = &node n.replicas[id.ReplicaID] = append(n.replicas[id.ReplicaID], &node) - builder := modules.NewConsensusBuilder(id.ReplicaID, pk) + builder := modules.NewBuilder(id.ReplicaID, pk) // register node as an anonymous module because that allows configuration to obtain it. - builder.Register(&node) + builder.Add(&node) return builder } @@ -119,7 +137,8 @@ func (n *Network) createTwinsNodes(nodes []NodeID, scenario Scenario, consensusN if !ok { return fmt.Errorf("unknown consensus module: '%s'", consensusName) } - builder.Register( + builder.Add( + eventloop.New(100), blockchain.New(), consensus.New(consensusModule), consensus.NewVotingMachine(), @@ -128,12 +147,12 @@ func (n *Network) createTwinsNodes(nodes []NodeID, scenario Scenario, consensusN logging.NewWithDest(&node.log, fmt.Sprintf("r%dn%d", nodeID.ReplicaID, nodeID.NetworkID)), // twins-specific: &configuration{network: n, node: node}, - leaderRotation(n.views), - commandModule{commandGenerator: cg, node: node}, &timeoutManager{network: n, node: node, timeout: 5}, + leaderRotation(n.views), + &commandModule{commandGenerator: cg, node: node}, ) - builder.OptionsBuilder().SetShouldVerifyVotesSync() - node.mods = builder.Build() + builder.Options().SetShouldVerifyVotesSync() + builder.Build() } return nil } @@ -141,8 +160,8 @@ func (n *Network) createTwinsNodes(nodes []NodeID, scenario Scenario, consensusN func (n *Network) run(ticks int) { // kick off the initial proposal(s) for _, node := range n.nodes { - if node.mods.LeaderRotation().GetLeader(1) == node.id.ReplicaID { - node.mods.Consensus().Propose(node.mods.Synchronizer().(*synchronizer.Synchronizer).SyncInfo()) + if node.leaderRotation.GetLeader(1) == node.id.ReplicaID { + node.consensus.Propose(node.synchronizer.(*synchronizer.Synchronizer).SyncInfo()) } } @@ -154,14 +173,14 @@ func (n *Network) run(ticks int) { // tick performs one tick for each node func (n *Network) tick() { for _, msg := range n.pendingMessages { - n.nodes[msg.receiver].mods.EventLoop().AddEvent(msg.message) + n.nodes[msg.receiver].eventLoop.AddEvent(msg.message) } n.pendingMessages = nil for _, node := range n.nodes { - node.mods.EventLoop().AddEvent(tick{}) + node.eventLoop.AddEvent(tick{}) // run each event loop as long as it has events - for node.mods.EventLoop().Tick() { + for node.eventLoop.Tick(context.Background()) { } } } @@ -176,10 +195,10 @@ func (n *Network) shouldDrop(sender, receiver uint32, message any) bool { // Index into viewPartitions. i := -1 - if node.effectiveView > node.mods.Synchronizer().View() { + if node.effectiveView > node.synchronizer.View() { i += int(node.effectiveView) } else { - i += int(node.mods.Synchronizer().View()) + i += int(node.synchronizer.View()) } if i < 0 { @@ -215,10 +234,9 @@ type configuration struct { } // alternative way to get a pointer to the node. -func (c *configuration) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { +func (c *configuration) InitModule(mods *modules.Core) { if c.node == nil { - mods.GetModuleByType(&c.node) - c.node.mods = mods + mods.TryGet(&c.node) } } @@ -323,7 +341,7 @@ func (c *configuration) Fetch(_ context.Context, hash hotstuff.Hash) (block *hot if c.shouldDrop(node.id, hash) { continue } - block, ok = node.mods.BlockChain().LocalGet(hash) + block, ok = node.blockChain.LocalGet(hash) if ok { return block, true } @@ -346,13 +364,13 @@ func (r *replica) ID() hotstuff.ID { // PublicKey returns the replica's public key. func (r *replica) PublicKey() hotstuff.PublicKey { - return r.config.network.replicas[r.id][0].mods.PrivateKey().Public() + return r.config.network.replicas[r.id][0].opts.PrivateKey().Public() } // Vote sends the partial certificate to the other replica. func (r *replica) Vote(cert hotstuff.PartialCert) { r.config.sendMessage(r.id, hotstuff.VoteMsg{ - ID: r.config.node.mods.ID(), + ID: r.config.node.opts.ID(), PartialCert: cert, }) } @@ -360,13 +378,13 @@ func (r *replica) Vote(cert hotstuff.PartialCert) { // NewView sends the quorum certificate to the other replica. func (r *replica) NewView(si hotstuff.SyncInfo) { r.config.sendMessage(r.id, hotstuff.NewViewMsg{ - ID: r.config.node.mods.ID(), + ID: r.config.node.opts.ID(), SyncInfo: si, }) } func (r *replica) Metadata() map[string]string { - return r.config.network.replicas[r.id][0].mods.Options().ConnectionMetadata() + return r.config.network.replicas[r.id][0].opts.ConnectionMetadata() } // NodeSet is a set of network ids. @@ -409,7 +427,9 @@ func (s *NodeSet) UnmarshalJSON(data []byte) error { type tick struct{} type timeoutManager struct { - mods *modules.ConsensusCore + synchronizer modules.Synchronizer + eventLoop *eventloop.EventLoop + node *node network *Network countdown int @@ -419,8 +439,8 @@ type timeoutManager struct { func (tm *timeoutManager) advance() { tm.countdown-- if tm.countdown == 0 { - view := tm.mods.Synchronizer().View() - tm.mods.EventLoop().AddEvent(synchronizer.TimeoutEvent{View: view}) + view := tm.synchronizer.View() + tm.eventLoop.AddEvent(synchronizer.TimeoutEvent{View: view}) tm.countdown = tm.timeout if tm.node.effectiveView <= view { tm.node.effectiveView = view + 1 @@ -440,12 +460,16 @@ func (tm *timeoutManager) viewChange(event synchronizer.ViewChangeEvent) { // InitModule gives the module a reference to the Modules object. // It also allows the module to set module options using the OptionsBuilder. -func (tm *timeoutManager) InitModule(mods *modules.ConsensusCore, _ *modules.OptionsBuilder) { - tm.mods = mods - tm.mods.EventLoop().RegisterObserver(tick{}, func(event any) { +func (tm *timeoutManager) InitModule(mods *modules.Core) { + mods.GetAll( + &tm.synchronizer, + &tm.eventLoop, + ) + + tm.eventLoop.RegisterObserver(tick{}, func(event any) { tm.advance() }) - tm.mods.EventLoop().RegisterObserver(synchronizer.ViewChangeEvent{}, func(event any) { + tm.eventLoop.RegisterObserver(synchronizer.ViewChangeEvent{}, func(event any) { tm.viewChange(event.(synchronizer.ViewChangeEvent)) }) } diff --git a/types.go b/types.go index 51cbfe8c..3e3551b2 100644 --- a/types.go +++ b/types.go @@ -9,6 +9,8 @@ import ( "io" "strconv" "strings" + + "github.com/relab/hotstuff/util" ) // IDSet implements a set of replica IDs. It is used to show which replicas participated in some event. @@ -101,7 +103,7 @@ func (h Hash) String() string { // Command is a client request to be executed by the consensus protocol. // // The string type is used because it is immutable and can hold arbitrary bytes of any length. -type Command string +type Command = string // ToBytes is an object that can be converted into bytes for the purposes of hashing, etc. type ToBytes interface { @@ -256,6 +258,16 @@ func NewQuorumCert(signature QuorumSignature, view View, hash Hash) QuorumCert { return QuorumCert{signature, view, hash} } +// WriteTo writes the quorum certificate to the writer. +func (qc QuorumCert) WriteTo(writer io.Writer) (n int64, err error) { + return util.WriteAllTo( + writer, + qc.view.ToBytes(), + qc.hash[:], + qc.signature, + ) +} + // ToBytes returns a byte representation of the quorum certificate. func (qc QuorumCert) ToBytes() []byte { b := qc.view.ToBytes() diff --git a/util/gpool/gpool.go b/util/gpool/gpool.go new file mode 100644 index 00000000..7f46db03 --- /dev/null +++ b/util/gpool/gpool.go @@ -0,0 +1,34 @@ +// Package gpool provides a generic sync.Pool. +package gpool + +import "sync" + +// Pool is a generic sync.Pool. +type Pool[T any] sync.Pool + +// New returns an initialized generic sync.Pool. +func New[T any](newFunc func() T) Pool[T] { + if newFunc != nil { + return Pool[T](sync.Pool{ + New: func() any { return newFunc() }, + }) + } + return Pool[T]{} +} + +// Get retrieves a resource from the pool. +// Returns the zero value of T if no resource is available and no New func is specified. +func (p *Pool[T]) Get() (val T) { + sp := (*sync.Pool)(p) + v := sp.Get() + if v != nil { + return v.(T) + } + return val +} + +// Put puts the resource into the pool. +func (p *Pool[T]) Put(val T) { + sp := (*sync.Pool)(p) + sp.Put(val) +} diff --git a/util/io.go b/util/io.go new file mode 100644 index 00000000..8a4cee57 --- /dev/null +++ b/util/io.go @@ -0,0 +1,58 @@ +package util + +import ( + "fmt" + "io" + "reflect" + "unsafe" +) + +type toBytes interface { + ToBytes() []byte +} + +type bytes interface { + Bytes() []byte +} + +// WriteAllTo writes all the data to the writer. +func WriteAllTo(writer io.Writer, data ...any) (n int64, err error) { + for _, d := range data { + var ( + nn int64 + nnn int + ) + switch d := d.(type) { + case io.WriterTo: + nn, err = d.WriteTo(writer) + case string: + nnn, err = writer.Write(unsafeStringToBytes(d)) + case []byte: + nnn, err = writer.Write(d) + case toBytes: + nnn, err = writer.Write(d.ToBytes()) + case bytes: + nnn, err = writer.Write(d.Bytes()) + case nil: + default: + panic(fmt.Sprintf("cannot write %T", d)) + } + nn += int64(nnn) + n += int64(nn) + if err != nil { + return n, err + } + } + return n, nil +} + +func unsafeStringToBytes(s string) []byte { + if s == "" { + return []byte{} + } + const max = 0x7fff0000 + if len(s) > max { + panic("string too long") + } + return (*[max]byte)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data))[:len(s):len(s)] +}