diff --git a/chttp/cookieauth.go b/chttp/cookieauth.go index 47b09a44..d4d48cef 100644 --- a/chttp/cookieauth.go +++ b/chttp/cookieauth.go @@ -31,7 +31,8 @@ type CookieAuth struct { client *Client // transport stores the original transport that is overridden by this auth // mechanism - transport http.RoundTripper + transport http.RoundTripper + authExpiry time.Time } var _ Authenticator = &CookieAuth{} @@ -48,26 +49,6 @@ func (a *CookieAuth) Authenticate(c *Client) error { return nil } -// shouldAuth returns true if there is no cookie set, or if it has expired. -func (a *CookieAuth) shouldAuth(req *http.Request) bool { - if _, err := req.Cookie(kivik.SessionCookieName); err == nil { - return false - } - cookie := a.Cookie() - if cookie == nil { - return true - } - if !cookie.Expires.IsZero() { - return cookie.Expires.Before(time.Now().Add(time.Minute)) - } - // If we get here, it means the server did not include an expiry time in - // the session cookie. Some CouchDB configurations do this, but rather than - // re-authenticating for every request, we'll let the session expire. A - // future change might be to make a client-configurable option to set the - // re-authentication timeout. - return false -} - // Cookie returns the current session cookie if found, or nil if not. func (a *CookieAuth) Cookie() *http.Cookie { if a.client == nil { @@ -96,6 +77,14 @@ func (a *CookieAuth) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return res, err } + for _, cookie := range res.Cookies() { + if cookie.Name == kivik.SessionCookieName { + a.client.authMU.Lock() + a.authExpiry = cookie.Expires.Add(-time.Minute) + a.client.authMU.Unlock() + break + } + } if res != nil && res.StatusCode == http.StatusUnauthorized { if cookie := a.Cookie(); cookie != nil { @@ -112,14 +101,9 @@ func (a *CookieAuth) authenticate(req *http.Request) error { if inProg, _ := ctx.Value(authInProgress).(bool); inProg { return nil } - if !a.shouldAuth(req) { - return nil - } a.client.authMU.Lock() defer a.client.authMU.Unlock() - if c := a.Cookie(); c != nil { - // In case another simultaneous process authenticated successfully first - req.AddCookie(c) + if !a.authExpiry.Before(time.Now()) { return nil } ctx = context.WithValue(ctx, authInProgress, true) @@ -132,6 +116,13 @@ func (a *CookieAuth) authenticate(req *http.Request) error { if _, err := a.client.DoError(ctx, http.MethodPost, "/_session", opts); err != nil { return err } + cookies := req.Cookies() + req.Header.Del("Cookie") + for _, cookie := range cookies { + if cookie.Name != kivik.SessionCookieName { + req.AddCookie(cookie) + } + } if c := a.Cookie(); c != nil { req.AddCookie(c) } diff --git a/chttp/cookieauth_test.go b/chttp/cookieauth_test.go index c9c99d68..5ccb1930 100644 --- a/chttp/cookieauth_test.go +++ b/chttp/cookieauth_test.go @@ -20,7 +20,6 @@ import ( "net/url" "strings" "testing" - "time" "gitlab.com/flimzy/testy" "golang.org/x/net/publicsuffix" @@ -177,92 +176,6 @@ func (j *dummyJar) SetCookies(_ *url.URL, cookies []*http.Cookie) { *j = cookies } -func Test_shouldAuth(t *testing.T) { - type tt struct { - a *CookieAuth - req *http.Request - want bool - } - - tests := testy.NewTable() - tests.Add("no session", tt{ - a: &CookieAuth{}, - req: httptest.NewRequest("GET", "/", nil), - want: true, - }) - tests.Add("authed request", func() interface{} { - req := httptest.NewRequest("GET", "/", nil) - req.AddCookie(&http.Cookie{Name: kivik.SessionCookieName}) - return tt{ - a: &CookieAuth{}, - req: req, - want: false, - } - }) - tests.Add("valid session", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(20 * time.Minute), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: false, - } - }) - tests.Add("expired session", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(-20 * time.Second), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: true, - } - }) - tests.Add("no expiry time", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: false, - } - }) - tests.Add("about to expire", func() interface{} { - c, _ := New("http://example.com/") - c.Jar = &dummyJar{&http.Cookie{ - Name: kivik.SessionCookieName, - Expires: time.Now().Add(20 * time.Second), - }} - a := &CookieAuth{client: c} - - return tt{ - a: a, - req: httptest.NewRequest("GET", "/", nil), - want: true, - } - }) - - tests.Run(t, func(t *testing.T, tt tt) { - got := tt.a.shouldAuth(tt.req) - if got != tt.want { - t.Errorf("Want %t, got %t", tt.want, got) - } - }) -} - func Test401Response(t *testing.T) { var sessCounter, getCounter int s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {