diff --git a/docs/api/middleware/session.md b/docs/api/middleware/session.md index 65d23681e7c..9dc68f217d3 100644 --- a/docs/api/middleware/session.md +++ b/docs/api/middleware/session.md @@ -22,6 +22,7 @@ func (s *Session) Get(key string) interface{} func (s *Session) Set(key string, val interface{}) func (s *Session) Delete(key string) func (s *Session) Destroy() error +func (s *Session) Reset() error func (s *Session) Regenerate() error func (s *Session) Save() error func (s *Session) Fresh() bool diff --git a/middleware/session/session.go b/middleware/session/session.go index fab7e4867bd..ebe00f6057b 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -125,6 +125,33 @@ func (s *Session) Regenerate() error { return nil } +// Reset generates a new session id, deletes the old one from storage, and resets the associated data +func (s *Session) Reset() error { + // Reset local data + if s.data != nil { + s.data.Reset() + } + // Reset byte buffer + if s.byteBuffer != nil { + s.byteBuffer.Reset() + } + // Reset expiration + s.exp = 0 + + // Delete old id from storage + if err := s.config.Storage.Delete(s.id); err != nil { + return err + } + + // Expire session + s.delSession() + + // Generate a new session, and set session.fresh to true + s.refresh() + + return nil +} + // refresh generates a new session, and set session.fresh to be true func (s *Session) refresh() { // Create a new id diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index fd1c686ace7..f153e33a2dd 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -324,11 +324,11 @@ func Test_Session_Save_Expiration(t *testing.T) { }) } -// go test -run Test_Session_Reset -func Test_Session_Reset(t *testing.T) { +// go test -run Test_Session_Destroy +func Test_Session_Destroy(t *testing.T) { t.Parallel() - t.Run("reset from cookie", func(t *testing.T) { + t.Run("destroy from cookie", func(t *testing.T) { t.Parallel() // session store store := New() @@ -347,7 +347,7 @@ func Test_Session_Reset(t *testing.T) { utils.AssertEqual(t, nil, name) }) - t.Run("reset from header", func(t *testing.T) { + t.Run("destroy from header", func(t *testing.T) { t.Parallel() // session store store := New(Config{ @@ -461,6 +461,76 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { utils.AssertEqual(t, nil, sess.Get("id")) } +// go test -run Test_Session_Reset +func Test_Session_Reset(t *testing.T) { + t.Parallel() + // fiber instance + app := fiber.New() + t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { + // session store + store := New() + // a random session uuid + originalSessionUUIDString := "" + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // now the session is in the storage + freshSession, err := store.Get(ctx) + utils.AssertEqual(t, nil, err) + + originalSessionUUIDString = freshSession.ID() + + // set a value + freshSession.Set("name", "fenny") + freshSession.Set("email", "fenny@example.com") + + err = freshSession.Save() + utils.AssertEqual(t, nil, err) + + // set cookie + ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString) + + // as the session is in the storage, session.fresh should be false + acquiredSession, err := store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, false, acquiredSession.Fresh()) + + err = acquiredSession.Reset() + utils.AssertEqual(t, nil, err) + + utils.AssertEqual(t, false, acquiredSession.ID() == originalSessionUUIDString) + + // acquiredSession.fresh should be true after resetting + utils.AssertEqual(t, true, acquiredSession.Fresh()) + + // Check that the session data has been reset + keys := acquiredSession.Keys() + utils.AssertEqual(t, []string{}, keys) + + // Set a new value for 'name' and check that it's updated + acquiredSession.Set("name", "john") + utils.AssertEqual(t, "john", acquiredSession.Get("name")) + utils.AssertEqual(t, nil, acquiredSession.Get("email")) + + // Save after resetting + err = acquiredSession.Save() + utils.AssertEqual(t, nil, err) + + // requesting entirely new context to prevent falsy tests + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // Check that the session id is not in the header or cookie anymore + utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName))) + utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName))) + + // But the new session id should be in the header or cookie + utils.AssertEqual(t, acquiredSession.ID(), string(ctx.Response().Header.Peek(store.sessionName))) + utils.AssertEqual(t, acquiredSession.ID(), string(ctx.Request().Header.Peek(store.sessionName))) + }) +} + // go test -run Test_Session_Regenerate // Regression: https://github.com/gofiber/fiber/issues/1395 func Test_Session_Regenerate(t *testing.T) {