diff --git a/rpc.go b/rpc.go index b2bb900..7a9f0e3 100644 --- a/rpc.go +++ b/rpc.go @@ -56,6 +56,13 @@ func responseTopic(base string, pid peer.ID) string { return path.Join(base, pid.String(), "_response") } +type ongoingMessage struct { + ctx context.Context + data []byte + opts []pubsub.PubOpt + respCh chan internalResponse +} + // Topic provides a nice interface to a libp2p pubsub topic. type Topic struct { ps *pubsub.PubSub @@ -63,7 +70,7 @@ type Topic struct { eventHandler EventHandler messageHandler MessageHandler - resChs map[cid.Cid]chan internalResponse + ongoing map[cid.Cid]ongoingMessage resTopic *Topic t *pubsub.Topic @@ -88,7 +95,6 @@ func NewTopic(ctx context.Context, ps *pubsub.PubSub, host peer.ID, topic string } t.resTopic.eventHandler = t.resEventHandler t.resTopic.messageHandler = t.resMessageHandler - t.resChs = make(map[cid.Cid]chan internalResponse) return t, nil } @@ -112,11 +118,12 @@ func newTopic(ctx context.Context, ps *pubsub.PubSub, host peer.ID, topic string } t := &Topic{ - ps: ps, - host: host, - t: top, - h: handler, - s: sub, + ps: ps, + host: host, + t: top, + h: handler, + s: sub, + ongoing: make(map[cid.Cid]ongoingMessage), } t.ctx, t.cancel = context.WithCancel(ctx) @@ -161,7 +168,7 @@ func (t *Topic) SetMessageHandler(handler MessageHandler) { t.messageHandler = handler } -// Publish data. See PublishOptions for option details. +// Publish data. Note that the data may arrive peers duplicated. See PublishOptions for option details. func (t *Topic) Publish( ctx context.Context, data []byte, @@ -175,57 +182,59 @@ func (t *Topic) Publish( } var respCh chan internalResponse - var msgID cid.Cid + msgID := cid.NewCidV1(cid.Raw, util.Hash(data)) if !args.ignoreResponse { - msgID = cid.NewCidV1(cid.Raw, util.Hash(data)) respCh = make(chan internalResponse) - t.lk.Lock() - t.resChs[msgID] = respCh - t.lk.Unlock() } + t.lk.Lock() + t.ongoing[msgID] = ongoingMessage{ + ctx: ctx, + data: data, + opts: args.pubOpts, + respCh: respCh, + } + t.lk.Unlock() if err := t.t.Publish(ctx, data, args.pubOpts...); err != nil { - return nil, fmt.Errorf("publishing to main topic: %v", err) + return nil, fmt.Errorf("publishing to topic: %v", err) } resultCh := make(chan Response) - if respCh != nil { - go func() { - defer func() { - t.lk.Lock() - delete(t.resChs, msgID) - t.lk.Unlock() - close(resultCh) - }() - for { - select { - case <-ctx.Done(): - if !args.multiResponse { - resultCh <- Response{ - Err: ErrResponseNotReceived, - } - } + go func() { + defer func() { + t.lk.Lock() + delete(t.ongoing, msgID) + t.lk.Unlock() + close(resultCh) + }() + for { + select { + case <-ctx.Done(): + if args.ignoreResponse { return - case r := <-respCh: - res := Response{ - ID: r.ID, - From: peer.ID(r.From), - Data: r.Data, - } - if r.Err != "" { - res.Err = errors.New(r.Err) - } + } + if !args.multiResponse { + resultCh <- Response{Err: ErrResponseNotReceived} + } + return + case r := <-respCh: + res := Response{ + ID: r.ID, + From: peer.ID(r.From), + Data: r.Data, + } + if r.Err != "" { + res.Err = errors.New(r.Err) + } + if !args.ignoreResponse { resultCh <- res - if !args.multiResponse { - return - } + } + if !args.multiResponse { + return } } - }() - } else { - close(resultCh) - } - + } + }() return resultCh, nil } @@ -239,6 +248,11 @@ func (t *Topic) watch() { switch e.Type { case pubsub.PeerJoin: msg = "JOINED" + // Note: it looks like we are publishing to this + // specific peer, but the rpc library doesn't have the + // ability, so it actually does is to republish to all + // peers. + t.republishTo(e.Peer) case pubsub.PeerLeave: msg = "LEFT" default: @@ -252,6 +266,19 @@ func (t *Topic) watch() { } } +func (t *Topic) republishTo(p peer.ID) { + t.lk.Lock() + for _, m := range t.ongoing { + go func(m ongoingMessage) { + log.Debugf("republishing %s because peer %s newly joins", t.t, p) + if err := t.t.Publish(m.ctx, m.data, m.opts...); err != nil { + log.Errorf("republishing to topic: %v", err) + } + }(m) + } + t.lk.Unlock() +} + func (t *Topic) listen() { for { msg, err := t.s.Next(t.ctx) @@ -328,12 +355,14 @@ func (t *Topic) resMessageHandler(from peer.ID, topic string, msg []byte) ([]byt log.Debugf("%s response from %s: %s", topic, from, res.ID) t.lk.Lock() - ch := t.resChs[id] + m, exists := t.ongoing[id] t.lk.Unlock() - if ch != nil { - ch <- res + if exists { + if m.respCh != nil { + m.respCh <- res + } } else { - log.Warnf("%s missed response from %s: %s", topic, from, res.ID) + log.Debugf("%s response from %s arrives too late, discarding", topic, from) } return nil, nil // no response to a response } diff --git a/rpc_test.go b/rpc_test.go index cf46f7a..05f5542 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -2,6 +2,7 @@ package rpc_test import ( "context" + "sync" "testing" "time" @@ -46,7 +47,7 @@ func TestPingPong(t *testing.T) { eventHandler := func(from core.ID, topic string, msg []byte) { t.Logf("%s event: %s %s", topic, from, msg) } - messageHandler := func(from core.ID, topic string, msg []byte) ([]byte, error) { + messageHandler := func(from core.ID, topic string, msg []byte) ([]byte, error) { // nolint:unparam t.Logf("%s message: %s %s", topic, from, msg) return []byte("pong"), nil } @@ -85,7 +86,43 @@ func TestPingPong(t *testing.T) { assert.NotEmpty(t, r2.ID) assert.Equal(t, p1.Host().ID().String(), r2.From.String()) + // test ignore response - make sure nothing weird happens. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err = t1.Publish(ctx, []byte("ping"), rpc.WithIgnoreResponse(true)) + require.NoError(t, err) + cancel() + + // test retries; peer1 requests "pong" from peer2, but peer2 joins topic after the request + t3, err := p1.NewTopic(context.Background(), "topic2", true) + require.NoError(t, err) + t3.SetEventHandler(eventHandler) + t3.SetMessageHandler(messageHandler) + fin.Add(t3) + + lk := sync.Mutex{} + go func() { + time.Sleep(time.Second) // wait until after peer1 publishes the request + + t4, err := p2.NewTopic(context.Background(), "topic2", true) + require.NoError(t, err) + t4.SetEventHandler(eventHandler) + t4.SetMessageHandler(messageHandler) + lk.Lock() + fin.Add(t4) + lk.Unlock() + }() + + // allow enough time for peer2 join event to be propagated. + ctx, cancel = context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + rc3, err := t3.Publish(ctx, []byte("ping")) + require.NoError(t, err) + r3 := <-rc3 + require.NoError(t, r3.Err) + + lk.Lock() require.NoError(t, fin.Cleanup(nil)) + lk.Unlock() } func TestMultiPingPong(t *testing.T) { @@ -115,7 +152,7 @@ func TestMultiPingPong(t *testing.T) { eventHandler := func(from core.ID, topic string, msg []byte) { t.Logf("%s event: %s %s", topic, from, msg) } - messageHandler := func(from core.ID, topic string, msg []byte) ([]byte, error) { + messageHandler := func(from core.ID, topic string, msg []byte) ([]byte, error) { // nolint:unparam t.Logf("%s message: %s %s", topic, from, msg) return []byte("pong"), nil } @@ -155,7 +192,55 @@ func TestMultiPingPong(t *testing.T) { } assert.Len(t, pongs, 2) + // test retries; peer1 requests "pong" from peer2 and peer3, but peer2 and peer3 join topic after the request + t4, err := p1.NewTopic(context.Background(), "topic2", true) + require.NoError(t, err) + t4.SetEventHandler(eventHandler) + t4.SetMessageHandler(messageHandler) + fin.Add(t4) + + var lk sync.Mutex + go func() { + time.Sleep(time.Second) // wait until after peer1 publishes the request + + t5, err := p2.NewTopic(context.Background(), "topic2", true) + require.NoError(t, err) + t5.SetEventHandler(eventHandler) + t5.SetMessageHandler(messageHandler) + lk.Lock() + fin.Add(t5) + lk.Unlock() + + t6, err := p3.NewTopic(context.Background(), "topic2", true) + require.NoError(t, err) + t6.SetEventHandler(eventHandler) + t6.SetMessageHandler(messageHandler) + lk.Lock() + fin.Add(t6) + lk.Unlock() + }() + // allow enough time for peer2 join event to be propagated. + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second*2) + defer cancel2() + rc2, err := t4.Publish( + ctx2, + []byte("ping"), + rpc.WithMultiResponse(true), + ) + require.NoError(t, err) + var pongs2 []struct{} + for r := range rc2 { + require.NotNil(t, r) + require.NoError(t, r.Err) + assert.Equal(t, "pong", string(r.Data)) + assert.NotEmpty(t, r.ID) + pongs2 = append(pongs2, struct{}{}) + } + assert.True(t, len(pongs2) >= 2, "at least 2 responses should have been received") + + lk.Lock() require.NoError(t, fin.Cleanup(nil)) + lk.Unlock() } func setLogLevels(systems map[string]logging.LogLevel) error {