Skip to content

Commit

Permalink
fix(go-zookeeper#36): reset session after client disconnected
Browse files Browse the repository at this point in the history
Signed-off-by: zwtop <[email protected]>
  • Loading branch information
Colin McIntosh authored and zwtop committed Jun 6, 2023
1 parent abd6db4 commit da86d29
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ func (c *Conn) sendRequest(
}

func (c *Conn) loop(ctx context.Context) {
disconnectTime := time.Time{}

for {
if err := c.connect(); err != nil {
// c.Close() was called
Expand All @@ -433,15 +435,19 @@ func (c *Conn) loop(ctx context.Context) {
err := c.authenticate()
switch {
case err == ErrSessionExpired:
c.logger.Printf("authentication failed: %s", err)
c.invalidateWatches(err)
c.logger.Printf("authentication expired: %s", err)
c.resetSession(err)
case err != nil && c.conn != nil:
c.logger.Printf("authentication failed: %s", err)
c.conn.Close()
if err == io.EOF && !disconnectTime.IsZero() && c.sessionExpired(disconnectTime) {
c.resetSession(err)
}
case err == nil:
if c.logInfo {
c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
}
disconnectTime = time.Time{} // reset disconnect time
c.hostProvider.Connected() // mark success
c.closeChan = make(chan struct{}) // channel to tell send loop stop

Expand Down Expand Up @@ -483,6 +489,7 @@ func (c *Conn) loop(ctx context.Context) {

c.sendSetWatches()
wg.Wait()
disconnectTime = time.Now()
}

c.setState(StateDisconnected)
Expand Down Expand Up @@ -703,9 +710,6 @@ func (c *Conn) authenticate() error {
return err
}
if r.SessionID == 0 {
atomic.StoreInt64(&c.sessionID, int64(0))
c.passwd = emptyPassword
c.lastZxid = 0
c.setState(StateExpired)
return ErrSessionExpired
}
Expand All @@ -718,6 +722,18 @@ func (c *Conn) authenticate() error {
return nil
}

func (c *Conn) sessionExpired(d time.Time) bool {
return d.Add(time.Duration(c.sessionTimeoutMs) * time.Millisecond).Before(time.Now())
}

func (c *Conn) resetSession(err error) {
c.logger.Printf("session reset on error: %s", err)
atomic.StoreInt64(&c.sessionID, int64(0))
c.passwd = emptyPassword
c.lastZxid = 0
c.invalidateWatches(err)
}

func (c *Conn) sendData(req *request) error {
header := &requestHeader{req.xid, req.opcode}
n, err := encodePacket(c.buf[4:], header)
Expand Down

0 comments on commit da86d29

Please sign in to comment.