diff --git a/client.go b/client.go index f7b49d1..ac55e34 100644 --- a/client.go +++ b/client.go @@ -163,7 +163,6 @@ type client struct { lastID int64 backoffFactory func() backoff.BackOff cancelFunc context.CancelFunc - wg sync.WaitGroup } func (c *client) Start() { diff --git a/client_test.go b/client_test.go index 8a50657..4589113 100644 --- a/client_test.go +++ b/client_test.go @@ -213,7 +213,7 @@ var _ = Describe("Client", func() { panicableInfo.shouldPanic.Store(true) panicableDebug.shouldPanic.Store(true) // Ensure that we really don't get any logs anymore - time.Sleep(1 * time.Second) + time.Sleep(500 * time.Millisecond) Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed)) server.cancel() close(done) diff --git a/loop.go b/loop.go index bfb128b..02b16b5 100644 --- a/loop.go +++ b/loop.go @@ -50,7 +50,10 @@ func (l *loop) Run(connected chan struct{}) (err error) { close(connected) // Process messages ch := make(chan receiveResult, 1) + wg := l.party.waitGroup() + wg.Add(1) go func() { + defer wg.Done() recv := l.hubConn.Receive() loop: for { diff --git a/party.go b/party.go index b2f4323..174b2c7 100644 --- a/party.go +++ b/party.go @@ -2,6 +2,7 @@ package signalr import ( "context" + "sync" "time" "github.com/go-kit/log" @@ -29,7 +30,7 @@ type Party interface { insecureSkipVerify() bool setInsecureSkipVerify(skip bool) - originPatterns() [] string + originPatterns() []string setOriginPatterns(orgs []string) chanReceiveTimeout() time.Duration @@ -50,6 +51,8 @@ type Party interface { maximumReceiveMessageSize() uint setMaximumReceiveMessageSize(size uint) + + waitGroup() *sync.WaitGroup } func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase { @@ -81,10 +84,11 @@ type partyBase struct { _streamBufferCapacity uint _maximumReceiveMessageSize uint _enableDetailedErrors bool - _insecureSkipVerify bool - _originPatterns []string + _insecureSkipVerify bool + _originPatterns []string info StructuredLogger dbg StructuredLogger + wg sync.WaitGroup } func (p *partyBase) context() context.Context { @@ -120,16 +124,16 @@ func (p *partyBase) setKeepAliveInterval(interval time.Duration) { } func (p *partyBase) insecureSkipVerify() bool { - return p._insecureSkipVerify + return p._insecureSkipVerify } func (p *partyBase) setInsecureSkipVerify(skip bool) { p._insecureSkipVerify = skip } func (p *partyBase) originPatterns() []string { - return p._originPatterns + return p._originPatterns } -func (p *partyBase) setOriginPatterns(origins []string) { +func (p *partyBase) setOriginPatterns(origins []string) { p._originPatterns = origins } @@ -173,3 +177,7 @@ func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) { func (p *partyBase) loggers() (info StructuredLogger, debug StructuredLogger) { return p.info, p.dbg } + +func (p *partyBase) waitGroup() *sync.WaitGroup { + return &p.wg +}