Skip to content

Commit

Permalink
Refactor the way we check for expired sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
flimzy committed Jun 8, 2021
1 parent 7c3aefb commit 68be90d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 114 deletions.
45 changes: 18 additions & 27 deletions chttp/cookieauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ type CookieAuth struct {
client *Client
// transport stores the original transport that is overridden by this auth
// mechanism
transport http.RoundTripper
transport http.RoundTripper
authExpiry time.Time
}

var _ Authenticator = &CookieAuth{}
Expand All @@ -48,26 +49,6 @@ func (a *CookieAuth) Authenticate(c *Client) error {
return nil
}

// shouldAuth returns true if there is no cookie set, or if it has expired.
func (a *CookieAuth) shouldAuth(req *http.Request) bool {
if _, err := req.Cookie(kivik.SessionCookieName); err == nil {
return false
}
cookie := a.Cookie()
if cookie == nil {
return true
}
if !cookie.Expires.IsZero() {
return cookie.Expires.Before(time.Now().Add(time.Minute))
}
// If we get here, it means the server did not include an expiry time in
// the session cookie. Some CouchDB configurations do this, but rather than
// re-authenticating for every request, we'll let the session expire. A
// future change might be to make a client-configurable option to set the
// re-authentication timeout.
return false
}

// Cookie returns the current session cookie if found, or nil if not.
func (a *CookieAuth) Cookie() *http.Cookie {
if a.client == nil {
Expand Down Expand Up @@ -96,6 +77,14 @@ func (a *CookieAuth) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil {
return res, err
}
for _, cookie := range res.Cookies() {
if cookie.Name == kivik.SessionCookieName {
a.client.authMU.Lock()
a.authExpiry = cookie.Expires.Add(-time.Minute)
a.client.authMU.Unlock()
break
}
}

if res != nil && res.StatusCode == http.StatusUnauthorized {
if cookie := a.Cookie(); cookie != nil {
Expand All @@ -112,14 +101,9 @@ func (a *CookieAuth) authenticate(req *http.Request) error {
if inProg, _ := ctx.Value(authInProgress).(bool); inProg {
return nil
}
if !a.shouldAuth(req) {
return nil
}
a.client.authMU.Lock()
defer a.client.authMU.Unlock()
if c := a.Cookie(); c != nil {
// In case another simultaneous process authenticated successfully first
req.AddCookie(c)
if !a.authExpiry.Before(time.Now()) {
return nil
}
ctx = context.WithValue(ctx, authInProgress, true)
Expand All @@ -132,6 +116,13 @@ func (a *CookieAuth) authenticate(req *http.Request) error {
if _, err := a.client.DoError(ctx, http.MethodPost, "/_session", opts); err != nil {
return err
}
cookies := req.Cookies()
req.Header.Del("Cookie")
for _, cookie := range cookies {
if cookie.Name != kivik.SessionCookieName {
req.AddCookie(cookie)
}
}
if c := a.Cookie(); c != nil {
req.AddCookie(c)
}
Expand Down
87 changes: 0 additions & 87 deletions chttp/cookieauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"net/url"
"strings"
"testing"
"time"

"gitlab.com/flimzy/testy"
"golang.org/x/net/publicsuffix"
Expand Down Expand Up @@ -177,92 +176,6 @@ func (j *dummyJar) SetCookies(_ *url.URL, cookies []*http.Cookie) {
*j = cookies
}

func Test_shouldAuth(t *testing.T) {
type tt struct {
a *CookieAuth
req *http.Request
want bool
}

tests := testy.NewTable()
tests.Add("no session", tt{
a: &CookieAuth{},
req: httptest.NewRequest("GET", "/", nil),
want: true,
})
tests.Add("authed request", func() interface{} {
req := httptest.NewRequest("GET", "/", nil)
req.AddCookie(&http.Cookie{Name: kivik.SessionCookieName})
return tt{
a: &CookieAuth{},
req: req,
want: false,
}
})
tests.Add("valid session", func() interface{} {
c, _ := New("http://example.com/")
c.Jar = &dummyJar{&http.Cookie{
Name: kivik.SessionCookieName,
Expires: time.Now().Add(20 * time.Minute),
}}
a := &CookieAuth{client: c}

return tt{
a: a,
req: httptest.NewRequest("GET", "/", nil),
want: false,
}
})
tests.Add("expired session", func() interface{} {
c, _ := New("http://example.com/")
c.Jar = &dummyJar{&http.Cookie{
Name: kivik.SessionCookieName,
Expires: time.Now().Add(-20 * time.Second),
}}
a := &CookieAuth{client: c}

return tt{
a: a,
req: httptest.NewRequest("GET", "/", nil),
want: true,
}
})
tests.Add("no expiry time", func() interface{} {
c, _ := New("http://example.com/")
c.Jar = &dummyJar{&http.Cookie{
Name: kivik.SessionCookieName,
}}
a := &CookieAuth{client: c}

return tt{
a: a,
req: httptest.NewRequest("GET", "/", nil),
want: false,
}
})
tests.Add("about to expire", func() interface{} {
c, _ := New("http://example.com/")
c.Jar = &dummyJar{&http.Cookie{
Name: kivik.SessionCookieName,
Expires: time.Now().Add(20 * time.Second),
}}
a := &CookieAuth{client: c}

return tt{
a: a,
req: httptest.NewRequest("GET", "/", nil),
want: true,
}
})

tests.Run(t, func(t *testing.T, tt tt) {
got := tt.a.shouldAuth(tt.req)
if got != tt.want {
t.Errorf("Want %t, got %t", tt.want, got)
}
})
}

func Test401Response(t *testing.T) {
var sessCounter, getCounter int
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 68be90d

Please sign in to comment.