From 6e76847f88aae9dfb233625f708f9a51ea99f7e4 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 28 May 2024 17:31:18 -0300 Subject: [PATCH 01/79] feat!(middleware/session): re-write session middleware with handler --- middleware/csrf/session_manager.go | 58 ++++++---- middleware/session/config.go | 10 +- middleware/session/middleware.go | 176 +++++++++++++++++++++++++++++ middleware/session/session.go | 34 +++--- 4 files changed, 236 insertions(+), 42 deletions(-) create mode 100644 middleware/session/middleware.go diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 87172eb838..21da1e54c1 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -4,7 +4,6 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/session" ) @@ -27,11 +26,22 @@ func newSessionManager(s *session.Store, k string) *sessionManager { // get token from session func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { - sess, err := m.session.Get(c) - if err != nil { - return nil + sess := session.FromContext(c) + var token Token + var ok bool + + if sess != nil { + token, ok = sess.Get(m.key).(Token) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return nil + } + token, ok = storeSess.Get(m.key).(Token) } - token, ok := sess.Get(m.key).(Token) + if ok { if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { return nil @@ -44,25 +54,33 @@ func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { // set token in session func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Duration) { - sess, err := m.session.Get(c) - if err != nil { - return - } - // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here - sess.Set(m.key, &Token{key, raw, time.Now().Add(exp)}) - if err := sess.Save(); err != nil { - log.Warn("csrf: failed to save session: ", err) + sess := session.FromContext(c) + if sess != nil { + // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here + sess.Set(m.key, &Token{key, raw, time.Now().Add(exp)}) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return + } + storeSess.Set(m.key, &Token{key, raw, time.Now().Add(exp)}) } } // delete token from session func (m *sessionManager) delRaw(c fiber.Ctx) { - sess, err := m.session.Get(c) - if err != nil { - return - } - sess.Delete(m.key) - if err := sess.Save(); err != nil { - log.Warn("csrf: failed to save session: ", err) + sess := session.FromContext(c) + if sess != nil { + sess.Delete(m.key) + } else { + // Try to get the session from the store + storeSess, err := m.session.Get(c) + if err != nil { + // Handle error + return + } + storeSess.Delete(m.key) } } diff --git a/middleware/session/config.go b/middleware/session/config.go index b98eeb2553..e3204d5cef 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -10,9 +10,9 @@ import ( // Config defines the config for middleware. type Config struct { - // Allowed session duration + // Allowed session idle duration // Optional. Default value 24 * time.Hour - Expiration time.Duration + IdleTimeout time.Duration // Storage interface to store the session data // Optional. Default value memory.New() @@ -70,7 +70,7 @@ const ( // ConfigDefault is the default config var ConfigDefault = Config{ - Expiration: 24 * time.Hour, + IdleTimeout: 24 * time.Hour, KeyLookup: "cookie:session_id", KeyGenerator: utils.UUIDv4, source: "cookie", @@ -88,8 +88,8 @@ func configDefault(config ...Config) Config { cfg := config[0] // Set default values - if int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration + if int(cfg.IdleTimeout.Seconds()) <= 0 { + cfg.IdleTimeout = ConfigDefault.IdleTimeout } if cfg.KeyLookup == "" { cfg.KeyLookup = ConfigDefault.KeyLookup diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go new file mode 100644 index 0000000000..e11ffdc18b --- /dev/null +++ b/middleware/session/middleware.go @@ -0,0 +1,176 @@ +package session + +import ( + "sync" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" +) + +// key for looking up session middleware in request context +const key = 0 + +// Session defines the session middleware configuration +type MiddlewareConfig struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Store defines the session store + // + // Required. + Store *Store + + // ErrorHandler defines a function which is executed for errors + // + // Optional. Default: nil + ErrorHandler func(*fiber.Ctx, error) +} + +type Middleware struct { + config MiddlewareConfig + Session *Session + ctx *fiber.Ctx + hasChanged bool + mu sync.RWMutex +} + +// Middleware pool +var middlewarePool = &sync.Pool{ + New: func() interface{} { + return &Middleware{} + }, +} + +// Session is a middleware to manage session state +// +// Session middleware manages common session state between requests. +// This middleware is dependent on the session store, which is responsible for +// storing the session data. +func NewMiddleware(config MiddlewareConfig) fiber.Handler { + return func(c fiber.Ctx) error { + // Don't execute middleware if Next returns true + if config.Next != nil && config.Next(c) { + return c.Next() + } + + // Get the session + session, err := config.Store.Get(c) + if err != nil { + return err + } + + // get a middleware from the pool + m := acquireMiddleware() + m.config = config + m.Session = session + m.ctx = &c + + // Store the middleware in the context + c.Locals(key, m) + + // Continue stack + stackErr := c.Next() + + // Save the session + // This is done after the response is sent to the client + // It allows us to modify the session data during the request + // Without having to worry about calling Save() + // + // It will also extend the session idle timeout automatically. + if err := session.Save(); err != nil { + if config.ErrorHandler != nil { + config.ErrorHandler(&c, err) + } else { + log.Errorf("session: %v", err) + } + } + + // release the middleware back to the pool + releaseMiddleware(m) + + return stackErr + } +} + +// acquireMiddleware returns a new Middleware from the pool +func acquireMiddleware() *Middleware { + return middlewarePool.Get().(*Middleware) +} + +// releaseMiddleware returns a Middleware to the pool +func releaseMiddleware(m *Middleware) { + m.config = MiddlewareConfig{} + m.Session = nil + m.ctx = nil + middlewarePool.Put(m) +} + +// FromContext returns the Middleware from the fiber context +func FromContext(c fiber.Ctx) *Middleware { + return c.Locals(key).(*Middleware) +} + +func (m *Middleware) Set(key string, value any) { + m.mu.Lock() + defer m.mu.Unlock() + + m.Session.Set(key, value) + m.hasChanged = true +} + +func (m *Middleware) Get(key string) any { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.Session.Get(key) +} + +func (m *Middleware) Delete(key string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.Session.Delete(key) + m.hasChanged = true +} + +func (m *Middleware) Destroy() error { + m.mu.Lock() + defer m.mu.Unlock() + + err := m.Session.Destroy() + m.reaquireSession() + return err +} + +func (m *Middleware) Fresh() bool { + return m.Session.Fresh() +} + +func (m *Middleware) ID() string { + return m.Session.ID() +} + +func (m *Middleware) Reset() error { + m.mu.Lock() + defer m.mu.Unlock() + + err := m.Session.Reset() + m.hasChanged = true + return err + +} + +func (m *Middleware) reaquireSession() { + if m.ctx == nil { + return + } + + session, err := m.config.Store.Get(*m.ctx) + if err != nil { + m.config.ErrorHandler(m.ctx, err) + } + m.Session = session + m.hasChanged = false +} diff --git a/middleware/session/session.go b/middleware/session/session.go index c257343968..3590f5384d 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -13,13 +13,13 @@ import ( ) type Session struct { - id string // session id - fresh bool // if new session - ctx fiber.Ctx // fiber context - config *Store // store configuration - data *data // key value data - byteBuffer *bytes.Buffer // byte buffer for the en- and decode - exp time.Duration // expiration of this session + id string // session id + fresh bool // if new session + ctx fiber.Ctx // fiber context + config *Store // store configuration + data *data // key value data + byteBuffer *bytes.Buffer // byte buffer for the en- and decode + idleTimeout time.Duration // idleTimeout of this session } var sessionPool = sync.Pool{ @@ -42,7 +42,7 @@ func acquireSession() *Session { func releaseSession(s *Session) { s.id = "" - s.exp = 0 + s.idleTimeout = 0 s.ctx = nil s.config = nil if s.data != nil { @@ -135,7 +135,7 @@ func (s *Session) Reset() error { s.byteBuffer.Reset() } // Reset expiration - s.exp = 0 + s.idleTimeout = 0 // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { @@ -167,9 +167,9 @@ func (s *Session) Save() error { return nil } - // Check if session has your own expiration, otherwise use default value - if s.exp <= 0 { - s.exp = s.config.Expiration + // Check if session has your own idle timeout, otherwise use default value + if s.idleTimeout <= 0 { + s.idleTimeout = s.config.IdleTimeout } // Update client cookie @@ -189,7 +189,7 @@ func (s *Session) Save() error { copy(encodedBytes, s.byteBuffer.Bytes()) // pass copied bytes with session id to provider - if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil { + if err := s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout); err != nil { return err } @@ -209,8 +209,8 @@ func (s *Session) Keys() []string { } // SetExpiry sets a specific expiration for this session -func (s *Session) SetExpiry(exp time.Duration) { - s.exp = exp +func (s *Session) SetIdleTimeout(idleTimeout time.Duration) { + s.idleTimeout = idleTimeout } func (s *Session) setSession() { @@ -226,8 +226,8 @@ func (s *Session) setSession() { // Cookies are also session cookies if they do not specify the Expires or Max-Age attribute. // refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie if !s.config.CookieSessionOnly { - fcookie.SetMaxAge(int(s.exp.Seconds())) - fcookie.SetExpire(time.Now().Add(s.exp)) + fcookie.SetMaxAge(int(s.idleTimeout.Seconds())) + fcookie.SetExpire(time.Now().Add(s.idleTimeout)) } fcookie.SetSecure(s.config.CookieSecure) fcookie.SetHTTPOnly(s.config.CookieHTTPOnly) From ac9a0287d288c779f7cdae152d32b5c53e97dfc6 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 28 May 2024 18:13:53 -0300 Subject: [PATCH 02/79] test(middleware/session): refactor to IdleTimeout --- middleware/session/session_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 02bd52d4e2..cc44137c06 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -365,7 +365,7 @@ func Test_Session_Save_Expiration(t *testing.T) { sess.Set("name", "john") // expire this session in 5 seconds - sess.SetExpiry(sessionDuration) + sess.SetIdleTimeout(sessionDuration) // save session err = sess.Save() @@ -443,12 +443,12 @@ func Test_Session_Destroy(t *testing.T) { func Test_Session_Custom_Config(t *testing.T) { t.Parallel() - store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }}) - require.Equal(t, time.Hour, store.Expiration) + store := New(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) + require.Equal(t, time.Hour, store.IdleTimeout) require.Equal(t, "very random", store.KeyGenerator()) - store = New(Config{Expiration: 0}) - require.Equal(t, ConfigDefault.Expiration, store.Expiration) + store = New(Config{IdleTimeout: 0}) + require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout) } // go test -run Test_Session_Cookie From 81f6789676f3f2a1c4a7ff488e2cb962ec263f95 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 28 May 2024 18:24:39 -0300 Subject: [PATCH 03/79] fix: lint errors --- middleware/session/middleware.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index e11ffdc18b..73116e847d 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -1,6 +1,8 @@ package session import ( + "errors" + "fmt" "sync" "github.com/gofiber/fiber/v3" @@ -38,7 +40,7 @@ type Middleware struct { // Middleware pool var middlewarePool = &sync.Pool{ - New: func() interface{} { + New: func() any { return &Middleware{} }, } @@ -94,9 +96,15 @@ func NewMiddleware(config MiddlewareConfig) fiber.Handler { } } +var ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") + // acquireMiddleware returns a new Middleware from the pool func acquireMiddleware() *Middleware { - return middlewarePool.Get().(*Middleware) + middleware, ok := middlewarePool.Get().(*Middleware) + if !ok { + panic(fmt.Errorf("%w", ErrTypeAssertionFailed)) + } + return middleware } // releaseMiddleware returns a Middleware to the pool @@ -109,7 +117,12 @@ func releaseMiddleware(m *Middleware) { // FromContext returns the Middleware from the fiber context func FromContext(c fiber.Ctx) *Middleware { - return c.Locals(key).(*Middleware) + m, ok := c.Locals(key).(*Middleware) + if !ok { + log.Warn("session: Session middleware not registered. See https://docs.gofiber.io/middleware/session") + return nil + } + return m } func (m *Middleware) Set(key string, value any) { @@ -159,7 +172,6 @@ func (m *Middleware) Reset() error { err := m.Session.Reset() m.hasChanged = true return err - } func (m *Middleware) reaquireSession() { From 28790cb34d6a5d925fc834607ad81365ca1e245f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 28 May 2024 18:29:20 -0300 Subject: [PATCH 04/79] test: Save session after setting or deleting raw data in CSRF middleware --- middleware/csrf/session_manager.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 21da1e54c1..51685b43b1 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -4,6 +4,7 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/session" ) @@ -66,6 +67,9 @@ func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Du return } storeSess.Set(m.key, &Token{key, raw, time.Now().Add(exp)}) + if err := storeSess.Save(); err != nil { + log.Warn("csrf: failed to save session: ", err) + } } } @@ -82,5 +86,8 @@ func (m *sessionManager) delRaw(c fiber.Ctx) { return } storeSess.Delete(m.key) + if err := storeSess.Save(); err != nil { + log.Warn("csrf: failed to save session: ", err) + } } } From 7ffae3d037e514f0c9d771ba5c02fbe4e3db1247 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 30 May 2024 13:51:42 -0300 Subject: [PATCH 05/79] Update middleware/session/middleware.go Co-authored-by: Renan Bastos --- middleware/session/middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 73116e847d..8a52b0cb5e 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -102,7 +102,7 @@ var ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") func acquireMiddleware() *Middleware { middleware, ok := middlewarePool.Get().(*Middleware) if !ok { - panic(fmt.Errorf("%w", ErrTypeAssertionFailed)) + panic(ErrTypeAssertionFailed.Error()) } return middleware } From 68f2739ada537a18d2cb43fece623c36059eb74f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 30 May 2024 14:24:15 -0300 Subject: [PATCH 06/79] fix: mutex and globals order --- middleware/session/middleware.go | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 8a52b0cb5e..8aa2fa4401 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -2,16 +2,12 @@ package session import ( "errors" - "fmt" "sync" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" ) -// key for looking up session middleware in request context -const key = 0 - // Session defines the session middleware configuration type MiddlewareConfig struct { // Next defines a function to skip this middleware when returned true. @@ -34,16 +30,23 @@ type Middleware struct { config MiddlewareConfig Session *Session ctx *fiber.Ctx - hasChanged bool + hasChanged bool // TODO: use this to optimize interaction with the session store mu sync.RWMutex } -// Middleware pool -var middlewarePool = &sync.Pool{ - New: func() any { - return &Middleware{} - }, -} +// key for looking up session middleware in request context +const key = 0 + +var ( + // ErrTypeAssertionFailed is returned when the type assertion failed + ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") + + middlewarePool = &sync.Pool{ + New: func() any { + return &Middleware{} + }, + } +) // Session is a middleware to manage session state // @@ -96,8 +99,6 @@ func NewMiddleware(config MiddlewareConfig) fiber.Handler { } } -var ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") - // acquireMiddleware returns a new Middleware from the pool func acquireMiddleware() *Middleware { middleware, ok := middlewarePool.Get().(*Middleware) @@ -134,9 +135,7 @@ func (m *Middleware) Set(key string, value any) { } func (m *Middleware) Get(key string) any { - m.mu.RLock() - defer m.mu.RUnlock() - + // no need to lock here, since the session has its own mutex return m.Session.Get(key) } From 92e687707951fddba53e272a5f7c0eed089bcf3f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 4 Jun 2024 12:14:43 -0300 Subject: [PATCH 07/79] feat: Re-Add read lock to session Get method --- middleware/session/middleware.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 8aa2fa4401..6acc3a67de 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -135,7 +135,9 @@ func (m *Middleware) Set(key string, value any) { } func (m *Middleware) Get(key string) any { - // no need to lock here, since the session has its own mutex + m.mu.RLock() + defer m.mu.RUnlock() + return m.Session.Get(key) } From 239db002e56a641d4185a6228fb64e39251abc1e Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 15:16:15 -0300 Subject: [PATCH 08/79] feat: Migrate New() to return middleware --- middleware/session/config.go | 15 ++++++++++ middleware/session/middleware.go | 44 +++++++++++++++------------- middleware/session/session.go | 6 ++++ middleware/session/session_test.go | 46 +++++++++++++++--------------- middleware/session/store.go | 18 ++++++++++-- middleware/session/store_test.go | 10 +++---- 6 files changed, 89 insertions(+), 50 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index e3204d5cef..a411e6e603 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -10,6 +10,21 @@ import ( // Config defines the config for middleware. type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Store defines the session store + // + // Required. + Store *Store + + // ErrorHandler defines a function which is executed for errors + // + // Optional. Default: nil + ErrorHandler func(*fiber.Ctx, error) + // Allowed session idle duration // Optional. Default value 24 * time.Hour IdleTimeout time.Duration diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 6acc3a67de..77e09e3e33 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -9,25 +9,9 @@ import ( ) // Session defines the session middleware configuration -type MiddlewareConfig struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c fiber.Ctx) bool - - // Store defines the session store - // - // Required. - Store *Store - - // ErrorHandler defines a function which is executed for errors - // - // Optional. Default: nil - ErrorHandler func(*fiber.Ctx, error) -} type Middleware struct { - config MiddlewareConfig + config Config Session *Session ctx *fiber.Ctx hasChanged bool // TODO: use this to optimize interaction with the session store @@ -53,8 +37,18 @@ var ( // Session middleware manages common session state between requests. // This middleware is dependent on the session store, which is responsible for // storing the session data. -func NewMiddleware(config MiddlewareConfig) fiber.Handler { - return func(c fiber.Ctx) error { +func New(config Config) fiber.Handler { + handler, _ := NewWithStore(config) + return handler +} + +// NewWithStore returns a new session middleware with the given store +func NewWithStore(config Config) (fiber.Handler, *Store) { + if config.Store == nil { + config.Store = newStore(config) + } + + handler := func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if config.Next != nil && config.Next(c) { return c.Next() @@ -97,6 +91,8 @@ func NewMiddleware(config MiddlewareConfig) fiber.Handler { return stackErr } + + return handler, config.Store } // acquireMiddleware returns a new Middleware from the pool @@ -110,7 +106,7 @@ func acquireMiddleware() *Middleware { // releaseMiddleware returns a Middleware to the pool func releaseMiddleware(m *Middleware) { - m.config = MiddlewareConfig{} + m.config = Config{} m.Session = nil m.ctx = nil middlewarePool.Put(m) @@ -187,3 +183,11 @@ func (m *Middleware) reaquireSession() { m.Session = session m.hasChanged = false } + +// Store returns the session store +func (m *Middleware) Store() *Store { + // TODO: Ensure that session.Save() can not be called + // on the store directly if the session is the same as the one in the middleware + // context. This is to prevent the session Save from invalidating the session. + return m.config.Store +} diff --git a/middleware/session/session.go b/middleware/session/session.go index 3590f5384d..a3b94bf6b0 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -162,6 +162,12 @@ func (s *Session) refresh() { // Save will update the storage and client cookie func (s *Session) Save() error { + // If the session is being used in the handler, it should not be saved + if _, ok := s.ctx.Locals(key).(*Middleware); ok { + // Session is in use, so we do nothing and return + return nil + } + // Better safe than sorry if s.data == nil { return nil diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index cc44137c06..61da6a1502 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -15,7 +15,7 @@ func Test_Session(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() @@ -98,7 +98,7 @@ func Test_Session_Types(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() @@ -265,7 +265,7 @@ func Test_Session_Types(t *testing.T) { func Test_Session_Store_Reset(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() // fiber context @@ -299,7 +299,7 @@ func Test_Session_Save(t *testing.T) { t.Run("save to cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() // fiber context @@ -319,7 +319,7 @@ func Test_Session_Save(t *testing.T) { t.Run("save to header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := newStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -350,7 +350,7 @@ func Test_Session_Save_Expiration(t *testing.T) { const sessionDuration = 5 * time.Second // session store - store := New() + store := newStore() // fiber instance app := fiber.New() // fiber context @@ -393,7 +393,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() // fiber context @@ -413,7 +413,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := newStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -443,11 +443,11 @@ func Test_Session_Destroy(t *testing.T) { func Test_Session_Custom_Config(t *testing.T) { t.Parallel() - store := New(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) + store := newStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) require.Equal(t, time.Hour, store.IdleTimeout) require.Equal(t, "very random", store.KeyGenerator()) - store = New(Config{IdleTimeout: 0}) + store = newStore(Config{IdleTimeout: 0}) require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout) } @@ -455,7 +455,7 @@ func Test_Session_Custom_Config(t *testing.T) { func Test_Session_Cookie(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber instance app := fiber.New() // fiber context @@ -474,7 +474,7 @@ func Test_Session_Cookie(t *testing.T) { // go test -run Test_Session_Cookie_In_Response func Test_Session_Cookie_In_Response(t *testing.T) { t.Parallel() - store := New() + store := newStore() app := fiber.New() // fiber context @@ -501,7 +501,7 @@ func Test_Session_Cookie_In_Response(t *testing.T) { // Regression: https://github.com/gofiber/fiber/issues/1365 func Test_Session_Deletes_Single_Key(t *testing.T) { t.Parallel() - store := New() + store := newStore() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -532,7 +532,7 @@ func Test_Session_Reset(t *testing.T) { app := fiber.New() // session store - store := New() + store := newStore() t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { t.Parallel() @@ -599,7 +599,7 @@ func Test_Session_Regenerate(t *testing.T) { t.Run("set fresh to be true when regenerating a session", func(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // a random session uuid originalSessionUUIDString := "" // fiber context @@ -636,7 +636,7 @@ func Test_Session_Regenerate(t *testing.T) { // go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4 func Benchmark_Session(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), newStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -652,7 +652,7 @@ func Benchmark_Session(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := newStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -672,7 +672,7 @@ func Benchmark_Session(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4 func Benchmark_Session_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), newStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -690,7 +690,7 @@ func Benchmark_Session_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := newStore(Config{ Storage: memory.New(), }) b.ReportAllocs() @@ -712,7 +712,7 @@ func Benchmark_Session_Parallel(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4 func Benchmark_Session_Asserted(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), newStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -730,7 +730,7 @@ func Benchmark_Session_Asserted(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := newStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -752,7 +752,7 @@ func Benchmark_Session_Asserted(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4 func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), New() + app, store := fiber.New(), newStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -771,7 +771,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := New(Config{ + store := newStore(Config{ Storage: memory.New(), }) b.ReportAllocs() diff --git a/middleware/session/store.go b/middleware/session/store.go index dbca801808..85f955d751 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -13,7 +13,10 @@ import ( ) // ErrEmptySessionID is an error that occurs when the session ID is empty. -var ErrEmptySessionID = errors.New("session id cannot be empty") +var ( + ErrEmptySessionID = errors.New("session id cannot be empty") + ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") +) type Store struct { Config @@ -21,7 +24,7 @@ type Store struct { var mux sync.Mutex -func New(config ...Config) *Store { +func newStore(config ...Config) *Store { // Set default config cfg := configDefault(config...) @@ -41,7 +44,18 @@ func (*Store) RegisterType(i any) { } // Get will get/create a session +// +// This function will return an ErrSessionAlreadyLoadedByMiddleware if +// the session is already loaded by the middleware func (s *Store) Get(c fiber.Ctx) (*Session, error) { + // If session is already loaded in the context, + // it should not be loaded again + _, ok := c.Locals(key).(*Middleware) + if ok { + return nil, ErrSessionAlreadyLoadedByMiddleware + } + + // Get session based on context var fresh bool loadData := true diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 1373827899..b2bc51a0fe 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -20,7 +20,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from cookie", func(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -33,7 +33,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from header", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := newStore(Config{ KeyLookup: "header:session_id", }) // fiber context @@ -48,7 +48,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from url query", func(t *testing.T) { t.Parallel() // session store - store := New(Config{ + store := newStore(Config{ KeyLookup: "query:session_id", }) // fiber context @@ -71,7 +71,7 @@ func Test_Store_Get(t *testing.T) { t.Run("session should persisted even session is invalid", func(t *testing.T) { t.Parallel() // session store - store := New() + store := newStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -91,7 +91,7 @@ func Test_Store_DeleteSession(t *testing.T) { // fiber instance app := fiber.New() // session store - store := New() + store := newStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) From 0b93c5cd1a738cd7dda6b74065db5e950febcf3f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 15:38:49 -0300 Subject: [PATCH 09/79] chore: Refactor session middleware to improve session handling --- middleware/session/middleware.go | 5 +---- middleware/session/session.go | 4 ++++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 77e09e3e33..c168ae37e0 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -78,7 +78,7 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { // Without having to worry about calling Save() // // It will also extend the session idle timeout automatically. - if err := session.Save(); err != nil { + if err := session.save(); err != nil { if config.ErrorHandler != nil { config.ErrorHandler(&c, err) } else { @@ -186,8 +186,5 @@ func (m *Middleware) reaquireSession() { // Store returns the session store func (m *Middleware) Store() *Store { - // TODO: Ensure that session.Save() can not be called - // on the store directly if the session is the same as the one in the middleware - // context. This is to prevent the session Save from invalidating the session. return m.config.Store } diff --git a/middleware/session/session.go b/middleware/session/session.go index a3b94bf6b0..919dfcda92 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -168,6 +168,10 @@ func (s *Session) Save() error { return nil } + return s.save() +} + +func (s *Session) save() error { // Better safe than sorry if s.data == nil { return nil From 7cb4a6e10abe5a3c283aa4e5b3d4dc31b7ca6832 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 15:44:31 -0300 Subject: [PATCH 10/79] chore: Private get on store --- middleware/session/middleware.go | 2 +- middleware/session/store.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index c168ae37e0..482420082a 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -55,7 +55,7 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { } // Get the session - session, err := config.Store.Get(c) + session, err := config.Store.get(c) if err != nil { return err } diff --git a/middleware/session/store.go b/middleware/session/store.go index 85f955d751..d2d6d50caf 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -55,6 +55,10 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { return nil, ErrSessionAlreadyLoadedByMiddleware } + return s.get(c) +} + +func (s *Store) get(c fiber.Ctx) (*Session, error) { // Get session based on context var fresh bool loadData := true From b4c8ea86a3049b480a3a5bcc1d6940bfc3de5fa0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 16:09:29 -0300 Subject: [PATCH 11/79] chore: Update session middleware to use saveSession instead of save --- middleware/session/middleware.go | 12 ++++++++---- middleware/session/session.go | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 482420082a..25d2eec7ca 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -75,10 +75,10 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { // Save the session // This is done after the response is sent to the client // It allows us to modify the session data during the request - // Without having to worry about calling Save() + // without having to worry about calling Save() on the session. // // It will also extend the session idle timeout automatically. - if err := session.save(); err != nil { + if err := session.saveSession(); err != nil { if config.ErrorHandler != nil { config.ErrorHandler(&c, err) } else { @@ -176,9 +176,13 @@ func (m *Middleware) reaquireSession() { return } - session, err := m.config.Store.Get(*m.ctx) + session, err := m.config.Store.get(*m.ctx) if err != nil { - m.config.ErrorHandler(m.ctx, err) + if m.config.ErrorHandler != nil { + m.config.ErrorHandler(m.ctx, err) + } else { + log.Errorf("session: %v", err) + } } m.Session = session m.hasChanged = false diff --git a/middleware/session/session.go b/middleware/session/session.go index 919dfcda92..f8ac9a85cb 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -168,10 +168,10 @@ func (s *Session) Save() error { return nil } - return s.save() + return s.saveSession() } -func (s *Session) save() error { +func (s *Session) saveSession() error { // Better safe than sorry if s.data == nil { return nil From aafee92650ac020768cd0fce074947232218cf83 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 16:19:09 -0300 Subject: [PATCH 12/79] chore: Update session middleware to use getSession instead of get --- middleware/session/config.go | 14 ++++++++++++++ middleware/session/middleware.go | 8 ++++---- middleware/session/store.go | 4 ++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index a411e6e603..e99e819517 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -5,6 +5,7 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -92,6 +93,15 @@ var ConfigDefault = Config{ sessionName: "session_id", } +func DefaultErrorHandler(c *fiber.Ctx, err error) { + log.Errorf("session: %v", err) + if c != nil { + if err := (*c).SendStatus(fiber.StatusInternalServerError); err != nil { + log.Errorf("session: %v", err) + } + } +} + // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided @@ -130,5 +140,9 @@ func configDefault(config ...Config) Config { } cfg.sessionName = selectors[1] + if cfg.ErrorHandler == nil { + cfg.ErrorHandler = DefaultErrorHandler + } + return cfg } diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 25d2eec7ca..e71e6dbbc2 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -55,7 +55,7 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { } // Get the session - session, err := config.Store.get(c) + session, err := config.Store.getSession(c) if err != nil { return err } @@ -82,7 +82,7 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { if config.ErrorHandler != nil { config.ErrorHandler(&c, err) } else { - log.Errorf("session: %v", err) + DefaultErrorHandler(&c, err) } } @@ -176,12 +176,12 @@ func (m *Middleware) reaquireSession() { return } - session, err := m.config.Store.get(*m.ctx) + session, err := m.config.Store.getSession(*m.ctx) if err != nil { if m.config.ErrorHandler != nil { m.config.ErrorHandler(m.ctx, err) } else { - log.Errorf("session: %v", err) + DefaultErrorHandler(m.ctx, err) } } m.Session = session diff --git a/middleware/session/store.go b/middleware/session/store.go index d2d6d50caf..9dcb3f4717 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -55,10 +55,10 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { return nil, ErrSessionAlreadyLoadedByMiddleware } - return s.get(c) + return s.getSession(c) } -func (s *Store) get(c fiber.Ctx) (*Session, error) { +func (s *Store) getSession(c fiber.Ctx) (*Session, error) { // Get session based on context var fresh bool loadData := true From cd91db451976a907a813145193b56004eb5c12c2 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 16:22:15 -0300 Subject: [PATCH 13/79] chore: Remove unused error handler in session middleware config --- middleware/session/config.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index e99e819517..c28b44ed81 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -140,9 +140,5 @@ func configDefault(config ...Config) Config { } cfg.sessionName = selectors[1] - if cfg.ErrorHandler == nil { - cfg.ErrorHandler = DefaultErrorHandler - } - return cfg } From c3b303f92b2a4852dace4950837a432c7bc44e21 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 16:25:03 -0300 Subject: [PATCH 14/79] chore: Update session middleware to use NewWithStore in CSRF tests --- middleware/csrf/csrf_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 70f7f032ee..fb1351bced 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -70,7 +70,7 @@ func Test_CSRF_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + _, store := session.NewWithStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -203,7 +203,7 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + _, store := session.NewWithStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -1072,7 +1072,7 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) { t.Parallel() // session store - store := session.New(session.Config{ + _, store := session.NewWithStore(session.Config{ KeyLookup: "cookie:_session", }) From 2731428c3aa2002fb55ca57480ec222a319690d0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 15 Jun 2024 17:29:06 -0300 Subject: [PATCH 15/79] test: add test --- middleware/session/middleware.go | 30 ++++++++++------ middleware/session/middleware_test.go | 52 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 11 deletions(-) create mode 100644 middleware/session/middleware_test.go diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index e71e6dbbc2..a68c0dde61 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -37,32 +37,40 @@ var ( // Session middleware manages common session state between requests. // This middleware is dependent on the session store, which is responsible for // storing the session data. -func New(config Config) fiber.Handler { - handler, _ := NewWithStore(config) +func New(config ...Config) fiber.Handler { + var handler fiber.Handler + if len(config) > 0 { + handler, _ = NewWithStore(config[0]) + } else { + handler, _ = NewWithStore() + } + return handler } // NewWithStore returns a new session middleware with the given store -func NewWithStore(config Config) (fiber.Handler, *Store) { - if config.Store == nil { - config.Store = newStore(config) +func NewWithStore(config ...Config) (fiber.Handler, *Store) { + cfg := configDefault(config...) + + if cfg.Store == nil { + cfg.Store = newStore(cfg) } handler := func(c fiber.Ctx) error { // Don't execute middleware if Next returns true - if config.Next != nil && config.Next(c) { + if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Get the session - session, err := config.Store.getSession(c) + session, err := cfg.Store.getSession(c) if err != nil { return err } // get a middleware from the pool m := acquireMiddleware() - m.config = config + m.config = cfg m.Session = session m.ctx = &c @@ -79,8 +87,8 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { // // It will also extend the session idle timeout automatically. if err := session.saveSession(); err != nil { - if config.ErrorHandler != nil { - config.ErrorHandler(&c, err) + if cfg.ErrorHandler != nil { + cfg.ErrorHandler(&c, err) } else { DefaultErrorHandler(&c, err) } @@ -92,7 +100,7 @@ func NewWithStore(config Config) (fiber.Handler, *Store) { return stackErr } - return handler, config.Store + return handler, cfg.Store } // acquireMiddleware returns a new Middleware from the pool diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go new file mode 100644 index 0000000000..a948bab381 --- /dev/null +++ b/middleware/session/middleware_test.go @@ -0,0 +1,52 @@ +package session + +import ( + "strings" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func TestNewWithStore(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New()) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + return c.SendString("value=" + id) + }) + app.Post("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + c.Cookie(&fiber.Cookie{ + Name: "session_id", + Value: id, + }) + return nil + }) + + h := app.Handler() + + // Test GET request without cookie + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, 200, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test GET request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, 200, ctx.Response.StatusCode()) + require.Equal(t, "value="+token, string(ctx.Response.Body())) +} From ee193dc88693cbae5bd5c2e58afc870a8c8e35cd Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 22 Jun 2024 17:35:11 -0300 Subject: [PATCH 16/79] fix: destroyed session and GHSA-98j2-3j3p-fw2v --- middleware/session/middleware.go | 45 ++---- middleware/session/middleware_test.go | 209 +++++++++++++++++++++++++- middleware/session/store.go | 50 +++--- 3 files changed, 244 insertions(+), 60 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index a68c0dde61..b00d9b5c02 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -15,6 +15,7 @@ type Middleware struct { Session *Session ctx *fiber.Ctx hasChanged bool // TODO: use this to optimize interaction with the session store + destroyed bool mu sync.RWMutex } @@ -80,17 +81,20 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { // Continue stack stackErr := c.Next() - // Save the session - // This is done after the response is sent to the client - // It allows us to modify the session data during the request - // without having to worry about calling Save() on the session. - // - // It will also extend the session idle timeout automatically. - if err := session.saveSession(); err != nil { - if cfg.ErrorHandler != nil { - cfg.ErrorHandler(&c, err) - } else { - DefaultErrorHandler(&c, err) + if !m.destroyed { + + // Save the session + // This is done after the response is sent to the client + // It allows us to modify the session data during the request + // without having to worry about calling Save() on the session. + // + // It will also extend the session idle timeout automatically. + if err := session.saveSession(); err != nil { + if cfg.ErrorHandler != nil { + cfg.ErrorHandler(&c, err) + } else { + DefaultErrorHandler(&c, err) + } } } @@ -158,7 +162,7 @@ func (m *Middleware) Destroy() error { defer m.mu.Unlock() err := m.Session.Destroy() - m.reaquireSession() + m.destroyed = true return err } @@ -179,23 +183,6 @@ func (m *Middleware) Reset() error { return err } -func (m *Middleware) reaquireSession() { - if m.ctx == nil { - return - } - - session, err := m.config.Store.getSession(*m.ctx) - if err != nil { - if m.config.ErrorHandler != nil { - m.config.ErrorHandler(m.ctx, err) - } else { - DefaultErrorHandler(m.ctx, err) - } - } - m.Session = session - m.hasChanged = false -} - // Store returns the session store func (m *Middleware) Store() *Store { return m.config.Store diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index a948bab381..0843c2be31 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -3,12 +3,217 @@ package session import ( "strings" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) +func TestMiddleware(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New()) + + app.Get("/get", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + value, ok := sess.Get("key").(string) + if !ok { + return c.Status(fiber.StatusNotFound).SendString("key not found") + } + return c.SendString("value=" + value) + }) + + app.Post("/set", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + // get a value from the body + value := c.FormValue("value") + sess.Set("key", value) + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/delete", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + sess.Delete("key") + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/reset", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Reset(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/destroy", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if err := sess.Destroy(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/fresh", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + // Reset the session to make it fresh + if err := sess.Reset(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + if sess.Fresh() { + return c.SendStatus(fiber.StatusOK) + } + return c.SendStatus(fiber.StatusInternalServerError) + }) + + // Test GET, SET, DELETE, RESET, DESTROY by sending requests to the respective routes + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + h := app.Handler() + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] + require.Equal(t, "key not found", string(ctx.Response.Body())) + + // Test POST /set + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/set") + ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type + ctx.Request.SetBodyString("value=hello") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test GET /get to check if the value was set + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + require.Equal(t, "value=hello", string(ctx.Response.Body())) + + // Test POST /delete to delete the value + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/delete") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test GET /get to check if the value was deleted + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + require.Equal(t, "key not found", string(ctx.Response.Body())) + + // Test POST /reset to reset the session + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/reset") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // verify we have a new session token + newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] + require.NotEqual(t, token, newToken) + token = newToken + + // Test POST /destroy to destroy the session + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/destroy") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Verify the session cookie is set to expire + setCookieHeader := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.Contains(t, setCookieHeader, "expires=") + cookieParts := strings.Split(setCookieHeader, ";") + expired := false + for _, part := range cookieParts { + if strings.Contains(part, "expires=") { + part = strings.TrimSpace(part) + expiryDateStr := strings.TrimPrefix(part, "expires=") + // Correctly parse the date with "GMT" timezone + expiryDate, err := time.Parse(time.RFC1123, strings.TrimSpace(expiryDateStr)) + require.NoError(t, err) + if expiryDate.Before(time.Now()) { + expired = true + break + } + } + } + require.True(t, expired, "Session cookie should be expired") + + // Sleep so that the session expires + time.Sleep(1 * time.Second) + + // Test GET /get to check if the session was destroyed + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/get") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) + // check that we have a new session token + newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + parts := strings.Split(newToken, ";") + require.Greater(t, len(parts), 2) + valueParts := strings.Split(parts[0], "=") + require.Greater(t, len(valueParts), 1) + newToken = valueParts[1] + require.NotEqual(t, token, newToken) + token = newToken + + // Test POST /fresh to check if the session is fresh + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/fresh") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // check that we have a new session token + newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] + require.NotEqual(t, token, newToken) +} + func TestNewWithStore(t *testing.T) { t.Parallel() app := fiber.New() @@ -36,7 +241,7 @@ func TestNewWithStore(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) - require.Equal(t, 200, ctx.Response.StatusCode()) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) // Get session cookie token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] @@ -47,6 +252,6 @@ func TestNewWithStore(t *testing.T) { ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("session_id", token) h(ctx) - require.Equal(t, 200, ctx.Response.StatusCode()) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) require.Equal(t, "value="+token, string(ctx.Response.Body())) } diff --git a/middleware/session/store.go b/middleware/session/store.go index 9dcb3f4717..fd9cdee76e 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -58,24 +58,28 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { return s.getSession(c) } +// Get session based on context func (s *Store) getSession(c fiber.Ctx) (*Session, error) { - // Get session based on context var fresh bool - loadData := true + var rawData []byte + var err error id := s.getSessionID(c) - if len(id) == 0 { - fresh = true - var err error - if id, err = s.responseCookies(c); err != nil { + // Attempt to fetch session data if an ID is provided + if len(id) > 0 { + rawData, err = s.Storage.Get(id) + // If error is nil and raw is nil then token is not in storage + if rawData == nil && err == nil { + id = "" // Reset ID to generate a new one + } else if err != nil { return nil, err } } - // If no key exist, create new one - if len(id) == 0 { - loadData = false + // If no ID is provided or data not found in storage, generate a new ID + if len(id) == 0 || err != nil { + fresh = true id = s.KeyGenerator() } @@ -86,26 +90,14 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { sess.id = id sess.fresh = fresh - // Fetch existing data - if loadData { - raw, err := s.Storage.Get(id) - // Unmarshal if we found data - switch { - case err != nil: - return nil, err - - case raw != nil: - mux.Lock() - defer mux.Unlock() - sess.byteBuffer.Write(raw) - encCache := gob.NewDecoder(sess.byteBuffer) - err := encCache.Decode(&sess.data.Data) - if err != nil { - return nil, fmt.Errorf("failed to decode session data: %w", err) - } - default: - // both raw and err is nil, which means id is not in the storage - sess.fresh = true + // Decode session data if found + if rawData != nil { + mux.Lock() + defer mux.Unlock() + _, _ = sess.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail + encCache := gob.NewDecoder(sess.byteBuffer) + if err := encCache.Decode(&sess.data.Data); err != nil { + return nil, fmt.Errorf("failed to decode session data: %w", err) } } From 1a5a3d7e1c65ad2f3212d40330f2a84b4e9baa98 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 29 Jul 2024 11:34:26 -0300 Subject: [PATCH 17/79] chore: Refactor session_test.go to use newStore() instead of New() --- middleware/session/session_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 9853c929d6..53516c96ce 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -863,7 +863,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { func Test_Session_Concurrency(t *testing.T) { t.Parallel() app := fiber.New() - store := New() + store := newStore() var wg sync.WaitGroup errChan := make(chan error, 10) // Buffered channel to collect errors @@ -877,7 +877,7 @@ func Test_Session_Concurrency(t *testing.T) { localCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) - sess, err := store.Get(localCtx) + sess, err := store.getSession(localCtx) if err != nil { errChan <- err return From 52e41a4d2544728791d3e61b47c8ee11504b551f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 29 Jul 2024 11:47:15 -0300 Subject: [PATCH 18/79] feat: Improve session middleware test coverage and error handling This commit improves the session middleware test coverage by adding assertions for the presence of the Set-Cookie header and the token value. It also enhances error handling by checking for the expected number of parts in the Set-Cookie header. --- middleware/session/middleware_test.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 0843c2be31..196acc9ca0 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -93,7 +93,10 @@ func TestMiddleware(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - token = strings.Split(strings.Split(token, ";")[0], "=")[1] + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Equal(t, 2, len(tokenParts), "Expected Set-Cookie header to contain a token") + token = tokenParts[1] require.Equal(t, "key not found", string(ctx.Response.Body())) // Test POST /set @@ -146,7 +149,10 @@ func TestMiddleware(t *testing.T) { require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) // verify we have a new session token newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") + newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) + require.Equal(t, 2, len(newTokenParts), "Expected Set-Cookie header to contain a token") + newToken = newTokenParts[1] require.NotEqual(t, token, newToken) token = newToken @@ -192,8 +198,9 @@ func TestMiddleware(t *testing.T) { require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode()) // check that we have a new session token newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") parts := strings.Split(newToken, ";") - require.Greater(t, len(parts), 2) + require.Greater(t, len(parts), 1) valueParts := strings.Split(parts[0], "=") require.Greater(t, len(valueParts), 1) newToken = valueParts[1] @@ -210,7 +217,10 @@ func TestMiddleware(t *testing.T) { require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) // check that we have a new session token newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] + require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") + newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) + require.Equal(t, 2, len(newTokenParts), "Expected Set-Cookie header to contain a token") + newToken = newTokenParts[1] require.NotEqual(t, token, newToken) } From ed95d83b5eafa73f3391be4fb621915af721e1d7 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 29 Jul 2024 12:02:11 -0300 Subject: [PATCH 19/79] chore: fix lint issues --- middleware/session/config.go | 17 +++++++++-------- middleware/session/middleware.go | 5 ++--- middleware/session/middleware_test.go | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index 2a1ca9664a..51ac0ab373 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -11,6 +11,10 @@ import ( // Config defines the config for middleware. type Config struct { + // Storage interface to store the session data + // Optional. Default value memory.New() + Storage fiber.Storage + // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil @@ -26,14 +30,6 @@ type Config struct { // Optional. Default: nil ErrorHandler func(*fiber.Ctx, error) - // Allowed session idle duration - // Optional. Default value 24 * time.Hour - IdleTimeout time.Duration - - // Storage interface to store the session data - // Optional. Default value memory.New() - Storage fiber.Storage - // KeyGenerator generates the session key. // Optional. Default value utils.UUIDv4 KeyGenerator func() string @@ -61,6 +57,11 @@ type Config struct { // The session name sessionName string + + // Allowed session idle duration + // Optional. Default value 24 * time.Hour + IdleTimeout time.Duration + // Allowed session duration // Optional. Default value 24 * time.Hour Expiration time.Duration diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index b00d9b5c02..591abb37fd 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -11,12 +11,12 @@ import ( // Session defines the session middleware configuration type Middleware struct { - config Config Session *Session ctx *fiber.Ctx + config Config + mu sync.RWMutex hasChanged bool // TODO: use this to optimize interaction with the session store destroyed bool - mu sync.RWMutex } // key for looking up session middleware in request context @@ -82,7 +82,6 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { stackErr := c.Next() if !m.destroyed { - // Save the session // This is done after the response is sent to the client // It allows us to modify the session data during the request diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 196acc9ca0..10fbde8301 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -95,7 +95,7 @@ func TestMiddleware(t *testing.T) { token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) require.NotEmpty(t, token, "Expected Set-Cookie header to be present") tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) - require.Equal(t, 2, len(tokenParts), "Expected Set-Cookie header to contain a token") + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") token = tokenParts[1] require.Equal(t, "key not found", string(ctx.Response.Body())) @@ -151,7 +151,7 @@ func TestMiddleware(t *testing.T) { newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) - require.Equal(t, 2, len(newTokenParts), "Expected Set-Cookie header to contain a token") + require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token") newToken = newTokenParts[1] require.NotEqual(t, token, newToken) token = newToken @@ -219,7 +219,7 @@ func TestMiddleware(t *testing.T) { newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present") newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2) - require.Equal(t, 2, len(newTokenParts), "Expected Set-Cookie header to contain a token") + require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token") newToken = newTokenParts[1] require.NotEqual(t, token, newToken) } From c6e1c344104e39d415343f4a407ebed40b7688ac Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 29 Jul 2024 12:23:09 -0300 Subject: [PATCH 20/79] chore: Fix session middleware locking issue and improve error handling --- middleware/session/middleware.go | 12 +++++++++++- middleware/session/middleware_test.go | 5 ++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 591abb37fd..0c6c858b88 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -71,17 +71,23 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { // get a middleware from the pool m := acquireMiddleware() + m.mu.Lock() m.config = cfg m.Session = session m.ctx = &c // Store the middleware in the context c.Locals(key, m) + m.mu.Unlock() // Continue stack stackErr := c.Next() - if !m.destroyed { + m.mu.RLock() + destroyed := m.destroyed + m.mu.RUnlock() + + if !destroyed { // Save the session // This is done after the response is sent to the client // It allows us to modify the session data during the request @@ -117,9 +123,13 @@ func acquireMiddleware() *Middleware { // releaseMiddleware returns a Middleware to the pool func releaseMiddleware(m *Middleware) { + m.mu.Lock() m.config = Config{} m.Session = nil m.ctx = nil + m.destroyed = false + m.hasChanged = false + m.mu.Unlock() middlewarePool.Put(m) } diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 10fbde8301..ed1e581242 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -254,7 +254,10 @@ func TestNewWithStore(t *testing.T) { require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) // Get session cookie token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - token = strings.Split(strings.Split(token, ";")[0], "=")[1] + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] require.Equal(t, "value="+token, string(ctx.Response.Body())) // Test GET request with cookie From 8a5663aabb8b4844d61201f31e270a6a81d36d54 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 3 Aug 2024 15:53:46 -0300 Subject: [PATCH 21/79] test: improve middleware test coverage and error handling --- middleware/session/middleware_test.go | 113 +++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index ed1e581242..d694be9197 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -10,7 +10,7 @@ import ( "github.com/valyala/fasthttp" ) -func TestMiddleware(t *testing.T) { +func Test_Session_Middleware(t *testing.T) { t.Parallel() app := fiber.New() @@ -268,3 +268,114 @@ func TestNewWithStore(t *testing.T) { require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) require.Equal(t, "value="+token, string(ctx.Response.Body())) } + +func Test_Session_FromSession(t *testing.T) { + t.Parallel() + app := fiber.New() + + sess := FromContext(app.AcquireCtx(&fasthttp.RequestCtx{})) + require.Nil(t, sess) + + app.Use(New()) +} + +func Test_Session_WithConfig(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Next: func(c fiber.Ctx) bool { + return c.Get("key") == "value" + }, + IdleTimeout: 1 * time.Second, + KeyLookup: "cookie:session_id_test", + KeyGenerator: func() string { + return "test" + }, + source: "cookie_test", + sessionName: "session_id_test", + })) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + return c.SendString("value=" + id) + }) + + app.Get("/isFresh", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess.Fresh() { + return c.SendStatus(fiber.StatusOK) + } + return c.SendStatus(fiber.StatusInternalServerError) + }) + + app.Post("/", func(c fiber.Ctx) error { + sess := FromContext(c) + id := sess.ID() + c.Cookie(&fiber.Cookie{ + Name: "session_id_test", + Value: id, + }) + return nil + }) + + h := app.Handler() + + // Test GET request without cookie + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test GET request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test POST request with cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id_test", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request without cookie + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request with wrong key + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test POST request with wrong value + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.SetCookie("session_id_test", "wrong") + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test idle timeout + time.Sleep(1200 * time.Millisecond) + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + ctx.Request.SetRequestURI("/isFresh") + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) +} From 46845e6b78c56bc0b48c61d6992f4b67efeed849 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 3 Aug 2024 15:54:48 -0300 Subject: [PATCH 22/79] test: Add idle timeout test case to session middleware test --- middleware/session/middleware_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index d694be9197..31e38d04a0 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -370,6 +370,14 @@ func Test_Session_WithConfig(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Check idle timeout not expired + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie("session_id_test", token) + ctx.Request.SetRequestURI("/isFresh") + h(ctx) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode()) + // Test idle timeout time.Sleep(1200 * time.Millisecond) ctx = &fasthttp.RequestCtx{} From ba0e49176a9ecb37546f2c33378c43a682ceb065 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 10 Aug 2024 13:36:51 -0300 Subject: [PATCH 23/79] feat: add GetSession(id string) (*Session, error) --- middleware/session/middleware.go | 3 ++ middleware/session/middleware_test.go | 56 ++++++++++++++++++++++++++- middleware/session/session.go | 18 +++++++-- middleware/session/store.go | 50 ++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 5 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 0c6c858b88..f92de96abe 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -101,6 +101,9 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { DefaultErrorHandler(&c, err) } } + + // Release the session back to the pool + releaseSession(session) } // release the middleware back to the pool diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 31e38d04a0..c5020b6060 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -2,6 +2,7 @@ package session import ( "strings" + "sync" "testing" "time" @@ -224,7 +225,7 @@ func Test_Session_Middleware(t *testing.T) { require.NotEqual(t, token, newToken) } -func TestNewWithStore(t *testing.T) { +func Test_Session_NewWithStore(t *testing.T) { t.Parallel() app := fiber.New() @@ -387,3 +388,56 @@ func Test_Session_WithConfig(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) } + +func Test_Session_Next(t *testing.T) { + t.Parallel() + + var ( + doNext bool + muNext sync.RWMutex + ) + + app := fiber.New() + + app.Use(New(Config{ + Next: func(c fiber.Ctx) bool { + muNext.RLock() + defer muNext.RUnlock() + return doNext + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + id := sess.ID() + return c.SendString("value=" + id) + }) + + h := app.Handler() + + // Test with Next returning false + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + // Get session cookie + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + require.NotEmpty(t, token, "Expected Set-Cookie header to be present") + tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2) + require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token") + token = tokenParts[1] + require.Equal(t, "value="+token, string(ctx.Response.Body())) + + // Test with Next returning true + muNext.Lock() + doNext = true + muNext.Unlock() + + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode()) +} diff --git a/middleware/session/session.go b/middleware/session/session.go index ecea2d324d..2cb0ee9083 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -41,6 +41,20 @@ func acquireSession() *Session { return s } +// Release releases the session back to the pool. +// +// This function should be called after the session is no longer needed. +// This function is used to reduce the number of allocations and +// to improve the performance of the session store. +// +// The session should not be used after calling this function. +func (sess *Session) Release() { + if sess == nil { + return + } + releaseSession(sess) +} + func releaseSession(s *Session) { s.mu.Lock() s.id = "" @@ -224,10 +238,6 @@ func (s *Session) saveSession() error { s.mu.Unlock() - // Release session - // TODO: It's not safe to use the Session after calling Save() - releaseSession(s) - return nil } diff --git a/middleware/session/store.go b/middleware/session/store.go index 925ffa1f26..f023f9869a 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -152,3 +152,53 @@ func (s *Store) Delete(id string) error { } return s.Storage.Delete(id) } + +// GetSession retrieves a session by its ID from the storage. +// If the session is not found, it returns nil and an error. +// +// Note: +// - Unlike session Middleware methods, Session methods do not automatically: +// - Load the session into the context +// - Save the session data to the storage and update the client cookie +// +// - Be aware of possible collisions if you are also using the session in a middleware. +// +// Usage: +// - If you modify a session returned by GetSession, you must call session.Save() to persist the changes. +// - When you are done with the session, you should call session.Release() to release the session back to the pool. +// +// Parameters: +// - id: The unique identifier of the session. +// +// Returns: +// - *Session: The session object if found, otherwise nil. +// - error: An error if the session retrieval fails or if the session ID is empty. +func (s *Store) GetSession(id string) (*Session, error) { + if id == "" { + return nil, ErrEmptySessionID + } + + rawData, err := s.Storage.Get(id) + if err != nil { + return nil, err + } + if rawData == nil { + return nil, nil + } + + sess := acquireSession() + + sess.mu.Lock() + defer sess.mu.Unlock() + + sess.id = id + sess.config = s + + sess.data.Lock() + defer sess.data.Unlock() + if err := sess.decodeSessionData(rawData); err != nil { + return nil, fmt.Errorf("failed to decode session data: %w", err) + } + + return sess, nil +} From d08b686eb9ef9130c1ca7422bd4771db43026c98 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sat, 10 Aug 2024 14:03:24 -0300 Subject: [PATCH 24/79] chore: lint --- middleware/session/middleware_test.go | 2 +- middleware/session/session.go | 6 +++--- middleware/session/store.go | 7 ++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index c5020b6060..6855ae9d8d 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -400,7 +400,7 @@ func Test_Session_Next(t *testing.T) { app := fiber.New() app.Use(New(Config{ - Next: func(c fiber.Ctx) bool { + Next: func(_ fiber.Ctx) bool { muNext.RLock() defer muNext.RUnlock() return doNext diff --git a/middleware/session/session.go b/middleware/session/session.go index 2cb0ee9083..7834ecca4b 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -48,11 +48,11 @@ func acquireSession() *Session { // to improve the performance of the session store. // // The session should not be used after calling this function. -func (sess *Session) Release() { - if sess == nil { +func (s *Session) Release() { + if s == nil { return } - releaseSession(sess) + releaseSession(s) } func releaseSession(s *Session) { diff --git a/middleware/session/store.go b/middleware/session/store.go index f023f9869a..3ddf28bcc7 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -14,6 +14,7 @@ import ( var ( ErrEmptySessionID = errors.New("session id cannot be empty") ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") + ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) // sessionIDKey is the local key type used to store and retrieve the session ID in context. @@ -153,7 +154,7 @@ func (s *Store) Delete(id string) error { return s.Storage.Delete(id) } -// GetSession retrieves a session by its ID from the storage. +// GetSessionByID retrieves a session by its ID from the storage. // If the session is not found, it returns nil and an error. // // Note: @@ -173,7 +174,7 @@ func (s *Store) Delete(id string) error { // Returns: // - *Session: The session object if found, otherwise nil. // - error: An error if the session retrieval fails or if the session ID is empty. -func (s *Store) GetSession(id string) (*Session, error) { +func (s *Store) GetSessionByID(id string) (*Session, error) { if id == "" { return nil, ErrEmptySessionID } @@ -183,7 +184,7 @@ func (s *Store) GetSession(id string) (*Session, error) { return nil, err } if rawData == nil { - return nil, nil + return nil, ErrSessionIDNotFoundInStore } sess := acquireSession() From c08ddc199e6ed7b558d7b40a946d145ec6938729 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:03:01 -0300 Subject: [PATCH 25/79] docs: Update session middleware docs --- docs/middleware/session.md | 405 ++++++++++++++++++++++++------- middleware/session/config.go | 24 +- middleware/session/data.go | 57 +++++ middleware/session/data_msgp.go | 372 ++++++++++++++++------------ middleware/session/middleware.go | 129 +++++++++- middleware/session/session.go | 128 +++++++++- middleware/session/store.go | 81 ++++++- 7 files changed, 922 insertions(+), 274 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 39b9ccc801..4d5f63d76c 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -2,142 +2,373 @@ id: session --- -# Session +# Session Middleware for [Fiber](https://github.com/gofiber/fiber) -Session middleware for [Fiber](https://github.com/gofiber/fiber). +The `session` middleware provides session handling for Fiber applications. It leverages the [Storage](https://github.com/gofiber/storage) package to offer support for multiple databases through a unified interface. By default, session data is stored in memory, but you can easily switch to other storage options, as shown in the examples below. :::note -This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory, see the examples below for other databases. +We recommend using the `Middleware` handler for better integration with other middleware. See the [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) section for details. ::: +## Table of Contents + +- [Migration Guide](#migration-guide) + - [v2 to v3](#v2-to-v3) +- [Types](#types) + - [Config](#config) + - [Middleware](#middleware) + - [Session](#session) + - [Store](#store) +- [Signatures](#signatures) + - [Session Package Functions](#session-package-functions) + - [Config Methods](#config-methods) + - [Middleware Methods](#middleware-methods) + - [Session Methods](#session-methods) + - [Store Methods](#store-methods) +- [Examples](#examples) + - [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) + - [Using a Custom Storage](#using-a-custom-storage) + - [Session without Middleware Handler](#session-without-middleware-handler) + - [Using Custom Types in Session Data](#using-custom-types-in-session-data) +- [Config](#config) +- [Default Config](#default-config) + +## Migration Guide + +### v2 to v3 + +- The `New` function signature has changed in v3. It now returns a `*Middleware` instead of a `*Store`. You can access the store using the `Store` method on the `*Middleware` or by using the `NewWithStore` function. + +While it's still possible to work with the `*Store` directly, we recommend using the `Middleware` handler for better integration with other Fiber middlewares. + +For more information about changes in Fiber v3, see [What's New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). + +#### v2 Example + +```go +store := session.New() + +app.Get("/", func(c *fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return err + } + + key, ok := sess.Get("key").(string) + if !ok { + return c.SendStatus(fiber.StatusInternalServerError) + } + + sess.Set("key", "value") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return nil +}) +``` + +#### v3 Example (Using Store) + +```go +_, store := session.NewWithStore() + +app.Get("/", func(c *fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return err + } + + key, ok := sess.Get("key").(string) + if !ok { + return c.SendStatus(fiber.StatusInternalServerError) + } + + sess.Set("key", "value") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return nil +}) +``` + +#### v3 Example (Using Middleware) + +See the [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) section for details. + +## Types + +### Config + +The configuration for the session middleware. + +```go +type Config struct { + Storage fiber.Storage + Next func(c *fiber.Ctx) bool + Store *Store + ErrorHandler func(*fiber.Ctx, error) + KeyGenerator func() string + KeyLookup string + CookieDomain string + CookiePath string + CookieSameSite string + IdleTimeout time.Duration + Expiration time.Duration + CookieSecure bool + CookieHTTPOnly bool + CookieSessionOnly bool +} +``` + +### Middleware + +The `Middleware` struct encapsulates the session middleware configuration and storage. It is created using the `New` or `NewWithStorage` function and used as a `fiber.Handler`. + +```go +type Middleware struct { + Session *Session +} +``` + +### Session + +The `Session` struct is used to interact with session data. You can retrieve it from the `Middleware` using the `FromContext` method or from the `Store` using the `Get` method. + +```go +type Session struct {} +``` + +### Store + +The `Store` struct is used to manage session data. It is created using the `NewWithStore` function or by calling the `Store` method on a `Middleware`. + +```go +type Store struct { + Config +} +``` + ## Signatures +### Session Package Functions + ```go -func New(config ...Config) *Store -func (s *Store) RegisterType(i any) -func (s *Store) Get(c fiber.Ctx) (*Session, error) -func (s *Store) Delete(id string) error -func (s *Store) Reset() error +func New(config ...Config) *Middleware +func NewWithStore(config ...Config) (fiber.Handler, *Store) +func FromContext(c fiber.Ctx) *Middleware +``` + +### Config Methods + +```go +func DefaultErrorHandler(c *fiber.Ctx, err error) +``` + +### Middleware Methods +Used to interact with session data when using the middleware handler. + +```go +func (m *Middleware) Set(key string, value any) +func (m *Middleware) Get(key string) any +func (m *Middleware) Delete(key string) +func (m *Middleware) Destroy() error +func (m *Middleware) Reset() error +func (m *Middleware) Store() *Store +``` + +### Session Methods + +If using the middleware handler, you generally won't need to use these methods directly. + +```go +func (s *Session) Fresh() bool +func (s *Session) ID() string func (s *Session) Get(key string) any func (s *Session) Set(key string, val any) -func (s *Session) Delete(key string) func (s *Session) Destroy() error -func (s *Session) Reset() error func (s *Session) Regenerate() error +func (s *Session) Reset() error func (s *Session) Save() error -func (s *Session) Fresh() bool -func (s *Session) ID() string func (s *Session) Keys() []string -func (s *Session) SetExpiry(exp time.Duration) +func (s *Session) SetIdleTimeout(idleTimeout time.Duration) ``` -:::caution -Storing `any` values are limited to built-ins Go types. -::: +### Store Methods + +```go +func (*Store) RegisterType(i any) +func (s *Store) Get(c fiber.Ctx) (*Session, error) +func (s *Store) Reset() error +func (s *Store) Delete(id string) error +func (s *Store) GetSessionByID(id string) (*Session, error) +``` ## Examples -Import the middleware package that is part of the Fiber web framework +### As a Middleware Handler (Recommended) ```go +package main + import ( + "fmt" + "log" + "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/session" + "github.com/gofiber/session/v3" + "github.com/gofiber/session/v3/middleware/csrf" ) -``` -After you initiate your Fiber app, you can use the following possibilities: +func main() { + app := fiber.New() -```go -// Initialize default config -// This stores all of your app's sessions -store := session.New() + sessionMiddleware, sessionStore := session.NewWithStore() -app.Get("/", func(c fiber.Ctx) error { - // Get session from storage - sess, err := store.Get(c) - if err != nil { - panic(err) - } + app.Use(sessionMiddleware) - // Get value - name := sess.Get("name") + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) - // Set key/value - sess.Set("name", "john") + app.Get("/", func(c *fiber.Ctx) error { + sess := session.FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } - // Get all Keys - keys := sess.Keys() + name, ok := sess.Get("name").(string) + if !ok { + return c.SendString("Welcome anonymous user!") + } - // Delete key - sess.Delete("name") + return c.SendString(fmt.Sprintf("Welcome %v", name)) + }) - // Destroy session - if err := sess.Destroy(); err != nil { - panic(err) - } + log.Fatal(app.Listen(":3000")) +} +``` - // Sets a specific expiration for this session - sess.SetExpiry(time.Second * 2) +### Using a Custom Storage - // Save session - if err := sess.Save(); err != nil { - panic(err) - } +This example shows how to use the `sqlite3` storage from the [Fiber storage package](https://github.com/gofiber/storage). - return c.SendString(fmt.Sprintf("Welcome %v", name)) -}) -``` +```go +package main -## Config +import ( + "log" -| Property | Type | Description | Default | -|:------------------------|:----------------|:------------------------------------------------------------------------------------------------------------|:----------------------| -| Expiration | `time.Duration` | Allowed session duration. | `24 * time.Hour` | -| Storage | `fiber.Storage` | Storage interface to store the session data. | `memory.New()` | -| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract session id from the request. | `"cookie:session_id"` | -| CookieDomain | `string` | Domain of the cookie. | `""` | -| CookiePath | `string` | Path of the cookie. | `""` | -| CookieSecure | `bool` | Indicates if cookie is secure. | `false` | -| CookieHTTPOnly | `bool` | Indicates if cookie is HTTP only. | `false` | -| CookieSameSite | `string` | Value of SameSite cookie. | `"Lax"` | -| CookieSessionOnly | `bool` | Decides whether cookie should last for only the browser session. Ignores Expiration if set to true. | `false` | -| KeyGenerator | `func() string` | KeyGenerator generates the session key. | `utils.UUIDv4` | -| CookieName (Deprecated) | `string` | Deprecated: Please use KeyLookup. The session name. | `""` | + "github.com/gofiber/fiber/v3" + "github.com/gofiber/storage/sqlite3" + "github.com/gofiber/session/v3" + "github.com/gofiber/session/v3/middleware/csrf" +) -## Default Config +func main() { + app := fiber.New() -```go -var ConfigDefault = Config{ - Expiration: 24 * time.Hour, - KeyLookup: "cookie:session_id", - KeyGenerator: utils.UUIDv4, - source: "cookie", - sessionName: "session_id", + storage := sqlite3.New() + sessionMiddleware, sessionStore := session.NewWithStore(session.Config{ + Storage: storage, + }) + + app.Use(sessionMiddleware) + + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) + + log.Fatal(app.Listen(":3000")) } ``` -## Constants +### Session without Middleware Handler + +This example shows how to work with sessions directly without the middleware handler. ```go -const ( - SourceCookie Source = "cookie" - SourceHeader Source = "header" - SourceURLQuery Source = "query" +package main + +import ( + "log" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/session/v3" + "github.com/gofiber/session/v3/middleware/csrf" ) -``` -### Custom Storage/Database +func main() { + app := fiber.New() -You can use any storage from our [storage](https://github.com/gofiber/storage/) package. + _, sessionStore := session.NewWithStore() -```go -storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3 + app.Use(csrf.New(csrf.Config{ + Store: sessionStore, + })) -store := session.New(session.Config{ - Storage: storage, -}) + app.Get("/", func(c *fiber.Ctx) error { + sess, err := sessionStore.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + name, ok := sess.Get("name").(string) + if !ok { + return c.SendString("Welcome anonymous user!") + } + + return c.SendString(fmt.Sprintf("Welcome %v", name)) + }) + + log.Fatal(app.Listen(":3000 + +")) +} ``` -To use the store, see the [Examples](#examples). +## Config + +| Property | Type | Description | Default | +|:------------------------|:----------------|:--------------------------------------------------------------------------------------------------------------|:----------------------| +| Storage | `fiber.Storage` | Storage interface to store the session data. | `memory.New()` | +| Next | `func(c fiber.Ctx) bool` | Function to skip this middleware when returned true. | `nil` | +| Store | `*Store` | Defines the session store. | `nil` (Required) | +| ErrorHandler | `func(*fiber.Ctx, error)` | Function executed for errors. | `nil` | +| KeyGenerator | `func() string` | KeyGenerator generates the session key. | `utils.UUIDv4` | +| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract session id from the request. | `"cookie:session_id"` | +| CookieDomain | `string` | Domain of the cookie. | `""` | +| CookiePath | `string` | Path of the cookie. | `""` | +| CookieSameSite | `string` | Value of SameSite cookie. | `"Lax"` | +| IdleTimeout | `time.Duration` | Allowed session idle duration. | `24 * time.Hour` | +| Expiration | `time.Duration` | Allowed session duration. | `24 * time.Hour` | +| CookieSecure | `bool` | Indicates if cookie is secure. | `false` | +| CookieHTTPOnly | `bool` | Indicates if cookie is HTTP only. | `false` | +| CookieSessionOnly | `bool` | Decides whether cookie should last for only the browser session. Ignores Expiration if set to true. | `false` | + +## Default Config + +```go +session.Config{ + Storage: memory.New(), + Next: nil, + Store: nil, + ErrorHandler: nil, + KeyGenerator: utils.UUIDv4, + KeyLookup: "cookie:session_id", + CookieDomain: "", + CookiePath: "", + CookieSameSite: "Lax", + IdleTimeout: 24 * time.Hour, + Expiration: 24 * time.Hour, + CookieSecure: false, + CookieHTTPOnly: false, + CookieSessionOnly: false, +} +``` \ No newline at end of file diff --git a/middleware/session/config.go b/middleware/session/config.go index 51ac0ab373..ed3e557b3b 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -74,7 +74,7 @@ type Config struct { // Optional. Default value false. CookieHTTPOnly bool - // Decides whether cookie should last for only the browser sesison. + // Decides whether cookie should last for only the browser session. // Ignores Expiration if set to true // Optional. Default value false. CookieSessionOnly bool @@ -97,6 +97,15 @@ var ConfigDefault = Config{ sessionName: "session_id", } +// DefaultErrorHandler logs the error and sends a 500 status code. +// +// Parameters: +// - c: The Fiber context. +// - err: The error to handle. +// +// Usage: +// +// DefaultErrorHandler(c, err) func DefaultErrorHandler(c *fiber.Ctx, err error) { log.Errorf("session: %v", err) if c != nil { @@ -106,7 +115,18 @@ func DefaultErrorHandler(c *fiber.Ctx, err error) { } } -// Helper function to set default values +// configDefault sets default values for the Config struct. +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - Config: The configuration with default values set. +// +// Usage: +// +// cfg := configDefault() +// cfg := configDefault(customConfig) func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { diff --git a/middleware/session/data.go b/middleware/session/data.go index 08cb833f4e..81732babcc 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -20,16 +20,40 @@ var dataPool = sync.Pool{ }, } +// acquireData returns a new data object from the pool. +// +// Returns: +// - *data: The data object. +// +// Usage: +// +// d := acquireData() func acquireData() *data { return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool } +// Reset clears the data map and resets the data object. +// +// Usage: +// +// d.Reset() func (d *data) Reset() { d.Lock() d.Data = make(map[string]any) d.Unlock() } +// Get retrieves a value from the data map by key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := d.Get("key") func (d *data) Get(key string) any { d.RLock() v := d.Data[key] @@ -37,18 +61,43 @@ func (d *data) Get(key string) any { return v } +// Set updates or creates a new key-value pair in the data map. +// +// Parameters: +// - key: The key to set. +// - value: The value to set. +// +// Usage: +// +// d.Set("key", "value") func (d *data) Set(key string, value any) { d.Lock() d.Data[key] = value d.Unlock() } +// Delete removes a key-value pair from the data map. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// d.Delete("key") func (d *data) Delete(key string) { d.Lock() delete(d.Data, key) d.Unlock() } +// Keys retrieves all keys in the data map. +// +// Returns: +// - []string: A slice of all keys in the data map. +// +// Usage: +// +// keys := d.Keys() func (d *data) Keys() []string { d.Lock() keys := make([]string, 0, len(d.Data)) @@ -59,6 +108,14 @@ func (d *data) Keys() []string { return keys } +// Len returns the number of key-value pairs in the data map. +// +// Returns: +// - int: The number of key-value pairs. +// +// Usage: +// +// length := d.Len() func (d *data) Len() int { return len(d.Data) } diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go index a93ffcfb27..ce3af1bd17 100644 --- a/middleware/session/data_msgp.go +++ b/middleware/session/data_msgp.go @@ -3,182 +3,236 @@ package session // Code generated by github.com/tinylib/msgp DO NOT EDIT. import ( - "github.com/tinylib/msgp/msgp" + "github.com/tinylib/msgp/msgp" ) // DecodeMsg implements msgp.Decodable +// +// This method decodes the session data from the provided msgp.Reader. +// +// Parameters: +// - dc: The msgp.Reader to decode from. +// +// Returns: +// - error: An error if the decoding fails. +// +// Usage: +// err := d.DecodeMsg(reader) func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, err = dc.ReadMapKeyPtr() - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - zb0002-- - var za0001 string - var za0002 interface{} - za0001, err = dc.ReadString() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } - default: - err = dc.Skip() - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - return + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Data": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + if z.Data == nil { + z.Data = make(map[string]interface{}, zb0002) + } else if len(z.Data) > 0 { + for key := range z.Data { + delete(z.Data, key) + } + } + for zb0002 > 0 { + zb0002-- + var za0001 string + var za0002 interface{} + za0001, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + za0002, err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + z.Data[za0001] = za0002 + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return } // EncodeMsg implements msgp.Encodable +// +// This method encodes the session data to the provided msgp.Writer. +// +// Parameters: +// - en: The msgp.Writer to encode to. +// +// Returns: +// - error: An error if the encoding fails. +// +// Usage: +// err := d.EncodeMsg(writer) func (z *data) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 1 - // write "Data" - err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - if err != nil { - return - } - err = en.WriteMapHeader(uint32(len(z.Data))) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - for za0001, za0002 := range z.Data { - err = en.WriteString(za0001) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - err = en.WriteIntf(za0002) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - } - return + // map header, size 1 + // write "Data" + err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) + if err != nil { + return + } + err = en.WriteMapHeader(uint32(len(z.Data))) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + for za0001, za0002 := range z.Data { + err = en.WriteString(za0001) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + err = en.WriteIntf(za0002) + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + } + return } // MarshalMsg implements msgp.Marshaler +// +// This method marshals the session data into a byte slice. +// +// Parameters: +// - b: The byte slice to marshal into. +// +// Returns: +// - []byte: The marshaled byte slice. +// - error: An error if the marshaling fails. +// +// Usage: +// b, err := d.MarshalMsg(nil) func (z *data) MarshalMsg(b []byte) (o []byte, err error) { - o = msgp.Require(b, z.Msgsize()) - // map header, size 1 - // string "Data" - o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - o = msgp.AppendMapHeader(o, uint32(len(z.Data))) - for za0001, za0002 := range z.Data { - o = msgp.AppendString(o, za0001) - o, err = msgp.AppendIntf(o, za0002) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - } - return + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "Data" + o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) + o = msgp.AppendMapHeader(o, uint32(len(z.Data))) + for za0001, za0002 := range z.Data { + o = msgp.AppendString(o, za0001) + o, err = msgp.AppendIntf(o, za0002) + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + } + return } // UnmarshalMsg implements msgp.Unmarshaler +// +// This method unmarshals the session data from a byte slice. +// +// Parameters: +// - bts: The byte slice to unmarshal from. +// +// Returns: +// - []byte: The remaining byte slice after unmarshaling. +// - error: An error if the unmarshaling fails. +// +// Usage: +// b, err := d.UnmarshalMsg(bts) func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - var za0001 string - var za0002 interface{} - zb0002-- - za0001, bts, err = msgp.ReadStringBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } - default: - bts, err = msgp.Skip(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - o = bts - return + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Data": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + if z.Data == nil { + z.Data = make(map[string]interface{}, zb0002) + } else if len(z.Data) > 0 { + for key := range z.Data { + delete(z.Data, key) + } + } + for zb0002 > 0 { + var za0001 string + var za0002 interface{} + zb0002-- + za0001, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + za0002, bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + z.Data[za0001] = za0002 + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return } // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +// +// This method returns the estimated size of the serialized session data. +// +// Returns: +// - int: The estimated size in bytes. +// +// Usage: +// size := d.Msgsize() func (z *data) Msgsize() (s int) { - s = 1 + 5 + msgp.MapHeaderSize - if z.Data != nil { - for za0001, za0002 := range z.Data { - _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.GuessSize(za0002) - } - } - return -} + s = 1 + 5 + msgp.MapHeaderSize + if z.Data != nil { + for za0001, za0002 := range z.Data { + _ = za0002 + s += msgp.StringPrefixSize + len(za0001) + msgp.GuessSize(za0002) + } + } + return +} \ No newline at end of file diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index f92de96abe..2da0f429f4 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -8,8 +8,7 @@ import ( "github.com/gofiber/fiber/v3/log" ) -// Session defines the session middleware configuration - +// Middleware defines the session middleware configuration type Middleware struct { Session *Session ctx *fiber.Ctx @@ -33,11 +32,17 @@ var ( } ) -// Session is a middleware to manage session state +// New creates a new session middleware with the given configuration. +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - fiber.Handler: The Fiber handler for the session middleware. // -// Session middleware manages common session state between requests. -// This middleware is dependent on the session store, which is responsible for -// storing the session data. +// Usage: +// +// app.Use(session.New()) func New(config ...Config) fiber.Handler { var handler fiber.Handler if len(config) > 0 { @@ -49,7 +54,18 @@ func New(config ...Config) fiber.Handler { return handler } -// NewWithStore returns a new session middleware with the given store +// NewWithStore returns a new session middleware with the given store. +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - fiber.Handler: The Fiber handler for the session middleware. +// - *Store: The session store. +// +// Usage: +// +// handler, store := session.NewWithStore() func NewWithStore(config ...Config) (fiber.Handler, *Store) { cfg := configDefault(config...) @@ -115,7 +131,14 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { return handler, cfg.Store } -// acquireMiddleware returns a new Middleware from the pool +// acquireMiddleware returns a new Middleware from the pool. +// +// Returns: +// - *Middleware: The middleware object. +// +// Usage: +// +// m := acquireMiddleware() func acquireMiddleware() *Middleware { middleware, ok := middlewarePool.Get().(*Middleware) if !ok { @@ -124,7 +147,14 @@ func acquireMiddleware() *Middleware { return middleware } -// releaseMiddleware returns a Middleware to the pool +// releaseMiddleware returns a Middleware to the pool. +// +// Parameters: +// - m: The middleware object to release. +// +// Usage: +// +// releaseMiddleware(m) func releaseMiddleware(m *Middleware) { m.mu.Lock() m.config = Config{} @@ -136,7 +166,17 @@ func releaseMiddleware(m *Middleware) { middlewarePool.Put(m) } -// FromContext returns the Middleware from the fiber context +// FromContext returns the Middleware from the Fiber context. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Middleware: The middleware object if found, otherwise nil. +// +// Usage: +// +// m := session.FromContext(c) func FromContext(c fiber.Ctx) *Middleware { m, ok := c.Locals(key).(*Middleware) if !ok { @@ -146,6 +186,15 @@ func FromContext(c fiber.Ctx) *Middleware { return m } +// Set sets a key-value pair in the session. +// +// Parameters: +// - key: The key to set. +// - value: The value to set. +// +// Usage: +// +// m.Set("key", "value") func (m *Middleware) Set(key string, value any) { m.mu.Lock() defer m.mu.Unlock() @@ -154,6 +203,17 @@ func (m *Middleware) Set(key string, value any) { m.hasChanged = true } +// Get retrieves a value from the session by key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := m.Get("key") func (m *Middleware) Get(key string) any { m.mu.RLock() defer m.mu.RUnlock() @@ -161,6 +221,14 @@ func (m *Middleware) Get(key string) any { return m.Session.Get(key) } +// Delete removes a key-value pair from the session. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// m.Delete("key") func (m *Middleware) Delete(key string) { m.mu.Lock() defer m.mu.Unlock() @@ -169,6 +237,14 @@ func (m *Middleware) Delete(key string) { m.hasChanged = true } +// Destroy destroys the session. +// +// Returns: +// - error: An error if the destruction fails. +// +// Usage: +// +// err := m.Destroy() func (m *Middleware) Destroy() error { m.mu.Lock() defer m.mu.Unlock() @@ -178,14 +254,38 @@ func (m *Middleware) Destroy() error { return err } +// Fresh checks if the session is fresh. +// +// Returns: +// - bool: True if the session is fresh, otherwise false. +// +// Usage: +// +// isFresh := m.Fresh() func (m *Middleware) Fresh() bool { return m.Session.Fresh() } +// ID returns the session ID. +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := m.ID() func (m *Middleware) ID() string { return m.Session.ID() } +// Reset resets the session. +// +// Returns: +// - error: An error if the reset fails. +// +// Usage: +// +// err := m.Reset() func (m *Middleware) Reset() error { m.mu.Lock() defer m.mu.Unlock() @@ -195,7 +295,14 @@ func (m *Middleware) Reset() error { return err } -// Store returns the session store +// Store returns the session store. +// +// Returns: +// - *Store: The session store. +// +// Usage: +// +// store := m.Store() func (m *Middleware) Store() *Store { return m.config.Store } diff --git a/middleware/session/session.go b/middleware/session/session.go index 7834ecca4b..eae6db8cc2 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -12,6 +12,7 @@ import ( "github.com/valyala/fasthttp" ) +// Session represents a user session. type Session struct { ctx fiber.Ctx // fiber context config *Store // store configuration @@ -29,6 +30,14 @@ var sessionPool = sync.Pool{ }, } +// acquireSession returns a new Session from the pool. +// +// Returns: +// - *Session: The session object. +// +// Usage: +// +// s := acquireSession() func acquireSession() *Session { s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool if s.data == nil { @@ -48,6 +57,10 @@ func acquireSession() *Session { // to improve the performance of the session store. // // The session should not be used after calling this function. +// +// Usage: +// +// s.Release() func (s *Session) Release() { if s == nil { return @@ -71,21 +84,45 @@ func releaseSession(s *Session) { sessionPool.Put(s) } -// Fresh is true if the current session is new +// Fresh returns true if the current session is new. +// +// Returns: +// - bool: True if the session is fresh, otherwise false. +// +// Usage: +// +// isFresh := s.Fresh() func (s *Session) Fresh() bool { s.mu.RLock() defer s.mu.RUnlock() return s.fresh } -// ID returns the session id +// ID returns the session id. +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := s.ID() func (s *Session) ID() string { s.mu.RLock() defer s.mu.RUnlock() return s.id } -// Get will return the value +// Get returns the value associated with the given key. +// +// Parameters: +// - key: The key to retrieve. +// +// Returns: +// - any: The value associated with the key. +// +// Usage: +// +// value := s.Get("key") func (s *Session) Get(key string) any { // Better safe than sorry if s.data == nil { @@ -94,7 +131,15 @@ func (s *Session) Get(key string) any { return s.data.Get(key) } -// Set will update or create a new key value +// Set updates or creates a new key-value pair in the session. +// +// Parameters: +// - key: The key to set. +// - val: The value to set. +// +// Usage: +// +// s.Set("key", "value") func (s *Session) Set(key string, val any) { // Better safe than sorry if s.data == nil { @@ -103,7 +148,14 @@ func (s *Session) Set(key string, val any) { s.data.Set(key, val) } -// Delete will delete the value +// Delete removes the key-value pair from the session. +// +// Parameters: +// - key: The key to delete. +// +// Usage: +// +// s.Delete("key") func (s *Session) Delete(key string) { // Better safe than sorry if s.data == nil { @@ -112,7 +164,14 @@ func (s *Session) Delete(key string) { s.data.Delete(key) } -// Destroy will delete the session from Storage and expire session cookie +// Destroy deletes the session from storage and expires the session cookie. +// +// Returns: +// - error: An error if the destruction fails. +// +// Usage: +// +// err := s.Destroy() func (s *Session) Destroy() error { // Better safe than sorry if s.data == nil { @@ -135,7 +194,14 @@ func (s *Session) Destroy() error { return nil } -// Regenerate generates a new session id and delete the old one from Storage +// Regenerate generates a new session id and deletes the old one from storage. +// +// Returns: +// - error: An error if the regeneration fails. +// +// Usage: +// +// err := s.Regenerate() func (s *Session) Regenerate() error { s.mu.Lock() defer s.mu.Unlock() @@ -151,7 +217,14 @@ func (s *Session) Regenerate() error { return nil } -// Reset generates a new session id, deletes the old one from storage, and resets the associated data +// Reset generates a new session id, deletes the old one from storage, and resets the associated data. +// +// Returns: +// - error: An error if the reset fails. +// +// Usage: +// +// err := s.Reset() func (s *Session) Reset() error { // Reset local data if s.data != nil { @@ -182,18 +255,25 @@ func (s *Session) Reset() error { return nil } -// refresh generates a new session, and set session.fresh to be true +// refresh generates a new session, and sets session.fresh to be true. func (s *Session) refresh() { s.id = s.config.KeyGenerator() s.fresh = true } -// Save will update the storage and client cookie +// Save updates the storage and client cookie. // // sess.Save() will save the session data to the storage and update the // client cookie, and it will release the session after saving. // // It's not safe to use the session after calling Save(). +// +// Returns: +// - error: An error if the save operation fails. +// +// Usage: +// +// err := s.Save() func (s *Session) Save() error { // If the session is being used in the handler, it should not be saved if _, ok := s.ctx.Locals(key).(*Middleware); ok { @@ -241,7 +321,14 @@ func (s *Session) saveSession() error { return nil } -// Keys will retrieve all keys in current session +// Keys retrieves all keys in the current session. +// +// Returns: +// - []string: A slice of all keys in the session. +// +// Usage: +// +// keys := s.Keys() func (s *Session) Keys() []string { if s.data == nil { return []string{} @@ -249,7 +336,14 @@ func (s *Session) Keys() []string { return s.data.Keys() } -// SetExpiry sets a specific expiration for this session +// SetIdleTimeout sets a specific expiration for this session. +// +// Parameters: +// - idleTimeout: The duration for the idle timeout. +// +// Usage: +// +// s.SetIdleTimeout(time.Hour) func (s *Session) SetIdleTimeout(idleTimeout time.Duration) { s.mu.Lock() defer s.mu.Unlock() @@ -320,6 +414,16 @@ func (s *Session) delSession() { } // decodeSessionData decodes the session data from raw bytes. +// +// Parameters: +// - rawData: The raw byte data to decode. +// +// Returns: +// - error: An error if the decoding fails. +// +// Usage: +// +// err := s.decodeSessionData(rawData) func (s *Session) decodeSessionData(rawData []byte) error { _, _ = s.byteBuffer.Write(rawData) encCache := gob.NewDecoder(s.byteBuffer) diff --git a/middleware/session/store.go b/middleware/session/store.go index 3ddf28bcc7..9fe2ff868e 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -44,14 +44,35 @@ func newStore(config ...Config) *Store { } // RegisterType registers a custom type for encoding/decoding into any storage provider. +// +// Parameters: +// - i: The custom type to register. +// +// Usage: +// +// store.RegisterType(MyCustomType{}) func (*Store) RegisterType(i any) { gob.Register(i) } -// Get will get/create a session +// Get will get/create a session. // // This function will return an ErrSessionAlreadyLoadedByMiddleware if -// the session is already loaded by the middleware +// the session is already loaded by the middleware. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Session: The session object. +// - error: An error if the session retrieval fails or if the session is already loaded by the middleware. +// +// Usage: +// +// sess, err := store.Get(c) +// if err != nil { +// // handle error +// } func (s *Store) Get(c fiber.Ctx) (*Session, error) { // If session is already loaded in the context, // it should not be loaded again @@ -63,7 +84,21 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) { return s.getSession(c) } -// Get session based on context +// getSession retrieves a session based on the context. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - *Session: The session object. +// - error: An error if the session retrieval fails. +// +// Usage: +// +// sess, err := store.getSession(c) +// if err != nil { +// // handle error +// } func (s *Store) getSession(c fiber.Ctx) (*Session, error) { var rawData []byte var err error @@ -118,6 +153,16 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { } // getSessionID returns the session ID from cookies, headers, or query string. +// +// Parameters: +// - c: The Fiber context. +// +// Returns: +// - string: The session ID. +// +// Usage: +// +// id := store.getSessionID(c) func (s *Store) getSessionID(c fiber.Ctx) string { id := c.Cookies(s.sessionName) if len(id) > 0 { @@ -142,11 +187,34 @@ func (s *Store) getSessionID(c fiber.Ctx) string { } // Reset deletes all sessions from the storage. +// +// Returns: +// - error: An error if the reset operation fails. +// +// Usage: +// +// err := store.Reset() +// if err != nil { +// // handle error +// } func (s *Store) Reset() error { return s.Storage.Reset() } // Delete deletes a session by its ID. +// +// Parameters: +// - id: The unique identifier of the session. +// +// Returns: +// - error: An error if the deletion fails or if the session ID is empty. +// +// Usage: +// +// err := store.Delete(id) +// if err != nil { +// // handle error +// } func (s *Store) Delete(id string) error { if id == "" { return ErrEmptySessionID @@ -174,6 +242,13 @@ func (s *Store) Delete(id string) error { // Returns: // - *Session: The session object if found, otherwise nil. // - error: An error if the session retrieval fails or if the session ID is empty. +// +// Usage: +// +// sess, err := store.GetSessionByID(id) +// if err != nil { +// // handle error +// } func (s *Store) GetSessionByID(id string) (*Session, error) { if id == "" { return nil, ErrEmptySessionID From 56f6ce059355466435ad88554040195c3a6de182 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:25:09 -0300 Subject: [PATCH 26/79] docs: Security Note to examples --- docs/middleware/session.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 4d5f63d76c..c894df1055 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -212,6 +212,21 @@ func (s *Store) GetSessionByID(id string) (*Session, error) ## Examples +:::note +**Security Note**: Fiber’s session middleware uses cookies with `SameSite=Lax` by default, which provides basic CSRF protection for most GET requests. However, for comprehensive security—especially for POST requests or sensitive operations (e.g., account changes, transactions, form submissions)—it is strongly recommended to use CSRF protection middleware. + +### Recommendations: +1. **Session Middleware Without CSRF**: + - You can use the `session` middleware without the `csrf` middleware or rely solely on `SameSite=Lax` for basic protection in low-risk scenarios. + +2. **Double Submit Cookie Pattern**: + - You can implement the **double submit cookie pattern** (via custom, third-party, or built-in middleware), where the CSRF token is stored in a cookie, and the request includes the token in a hidden field or header. In this approach, there is no need to pass the `session.Store` to the `csrf` middleware. Simply apply the `session.New()` and `csrf.New()` middleware to the routes you want to protect. + +3. **Recommended Approach**: + - For stronger protection, especially in high-risk scenarios, use the `csrf` middleware with the session store. This method implements the **Synchronizer Token Pattern**, providing robust defense by associating the CSRF token with the user’s session. This approach requires passing the `session.Store` to the `csrf` middleware. + - Ensure the CSRF token is embedded in forms or included in a header for POST requests and verified on the server side for incoming requests. This adds a crucial security layer for state-changing actions. +::: + ### As a Middleware Handler (Recommended) ```go From 9e406f4eaa9254df1db32b629eaf9e1ee247d8a4 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:28:31 -0300 Subject: [PATCH 27/79] docs: Add recommendation for CSRF protection in session middleware --- docs/middleware/session.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index c894df1055..4d45c49aa8 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -213,7 +213,7 @@ func (s *Store) GetSessionByID(id string) (*Session, error) ## Examples :::note -**Security Note**: Fiber’s session middleware uses cookies with `SameSite=Lax` by default, which provides basic CSRF protection for most GET requests. However, for comprehensive security—especially for POST requests or sensitive operations (e.g., account changes, transactions, form submissions)—it is strongly recommended to use CSRF protection middleware. +**Security Note**: Fiber’s session middleware uses cookies with `SameSite=Lax` by default, which provides basic CSRF protection for most GET requests. However, for comprehensive security—especially for POST requests or sensitive operations (e.g., account changes, transactions, form submissions)—it is strongly recommended to use CSRF protection middleware. Fiber provides a `csrf` middleware that can be used in conjunction with the `session` middleware for robust protection. Find more information in the [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) documentation. ### Recommendations: 1. **Session Middleware Without CSRF**: From 12b219a6e80affcbede123c526a8a60d2e96eaec Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:43:09 -0300 Subject: [PATCH 28/79] chore: markdown lint --- docs/middleware/session.md | 70 +++++++++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 4d45c49aa8..488be3a48c 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -28,7 +28,7 @@ We recommend using the `Middleware` handler for better integration with other mi - [Examples](#examples) - [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) - [Using a Custom Storage](#using-a-custom-storage) - - [Session without Middleware Handler](#session-without-middleware-handler) + - [Session Without Middleware Handler](#session-without-middleware-handler) - [Using Custom Types in Session Data](#using-custom-types-in-session-data) - [Config](#config) - [Default Config](#default-config) @@ -38,7 +38,7 @@ We recommend using the `Middleware` handler for better integration with other mi ### v2 to v3 - The `New` function signature has changed in v3. It now returns a `*Middleware` instead of a `*Store`. You can access the store using the `Store` method on the `*Middleware` or by using the `NewWithStore` function. - + While it's still possible to work with the `*Store` directly, we recommend using the `Middleware` handler for better integration with other Fiber middlewares. For more information about changes in Fiber v3, see [What's New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). @@ -215,7 +215,8 @@ func (s *Store) GetSessionByID(id string) (*Session, error) :::note **Security Note**: Fiber’s session middleware uses cookies with `SameSite=Lax` by default, which provides basic CSRF protection for most GET requests. However, for comprehensive security—especially for POST requests or sensitive operations (e.g., account changes, transactions, form submissions)—it is strongly recommended to use CSRF protection middleware. Fiber provides a `csrf` middleware that can be used in conjunction with the `session` middleware for robust protection. Find more information in the [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) documentation. -### Recommendations: +### Recommendations + 1. **Session Middleware Without CSRF**: - You can use the `session` middleware without the `csrf` middleware or rely solely on `SameSite=Lax` for basic protection in low-risk scenarios. @@ -225,6 +226,7 @@ func (s *Store) GetSessionByID(id string) (*Session, error) 3. **Recommended Approach**: - For stronger protection, especially in high-risk scenarios, use the `csrf` middleware with the session store. This method implements the **Synchronizer Token Pattern**, providing robust defense by associating the CSRF token with the user’s session. This approach requires passing the `session.Store` to the `csrf` middleware. - Ensure the CSRF token is embedded in forms or included in a header for POST requests and verified on the server side for incoming requests. This adds a crucial security layer for state-changing actions. + ::: ### As a Middleware Handler (Recommended) @@ -304,7 +306,7 @@ func main() { } ``` -### Session without Middleware Handler +### Session Without Middleware Handler This example shows how to work with sessions directly without the middleware handler. @@ -348,6 +350,64 @@ func main() { } ``` +### Using Custom Types in Session Data + +Session data can only be of the following types by default: + +- `string` +- `int` +- `int8` +- `int16` +- `int32` +- `int64` +- `uint` +- `uint8` +- `uint16` +- `uint32` +- `uint64` +- `bool` +- `float32` +- `float64` +- `[]byte` +- `complex64` +- `complex128` +- `interface{}` + +To support other types in session data, you can register custom types. Here is an example of how to register a custom type: + +```go +package main + +import ( + "log" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/session/v3" + "github.com/gofiber/session/v3/middleware/session" +) + +type User struct { + Name string + Age int +} + +func main() { + // Create a new Fiber app + app := fiber.New() + + // Initialize custom session config + sessionMiddleware, sessionStore := session.NewWithStore() + + // Register custom type + sessionStore.RegisterType(User{}) + + // Use the session middleware + app.Use(sessionMiddleware) + + ... +} +``` + ## Config | Property | Type | Description | Default | @@ -386,4 +446,4 @@ session.Config{ CookieHTTPOnly: false, CookieSessionOnly: false, } -``` \ No newline at end of file +``` From 6812fc40561c75a214ac1ca608cecf141f3eb44c Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:53:18 -0300 Subject: [PATCH 29/79] docs: Update session middleware docs --- docs/middleware/session.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 488be3a48c..5439ef55b5 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -36,8 +36,7 @@ We recommend using the `Middleware` handler for better integration with other mi ## Migration Guide ### v2 to v3 - -- The `New` function signature has changed in v3. It now returns a `*Middleware` instead of a `*Store`. You can access the store using the `Store` method on the `*Middleware` or by using the `NewWithStore` function. +- In version 3, the `New` function signature has been updated. It now returns a Fiber middleware handler instead of a `*Store`. To access the store, you can use the `Store` method on the `*Middleware` (obtained by calling `session.FromContext(c)` in a handler where the middleware is applied) or utilize the `NewWithStore` function. While it's still possible to work with the `*Store` directly, we recommend using the `Middleware` handler for better integration with other Fiber middlewares. From 28aad65059cc87b60e23740f1473402aac4ab770 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 8 Sep 2024 15:55:34 -0300 Subject: [PATCH 30/79] docs: makrdown lint --- docs/middleware/session.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 5439ef55b5..6a02c78f15 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -36,6 +36,7 @@ We recommend using the `Middleware` handler for better integration with other mi ## Migration Guide ### v2 to v3 + - In version 3, the `New` function signature has been updated. It now returns a Fiber middleware handler instead of a `*Store`. To access the store, you can use the `Store` method on the `*Middleware` (obtained by calling `session.FromContext(c)` in a handler where the middleware is applied) or utilize the `NewWithStore` function. While it's still possible to work with the `*Store` directly, we recommend using the `Middleware` handler for better integration with other Fiber middlewares. From 14c7a6cff9a36f5c63579baed66085bf1bc1a4fb Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 14:00:09 -0300 Subject: [PATCH 31/79] test(middleware/session): Add unit tests for session config.go --- middleware/session/config_test.go | 59 +++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 middleware/session/config_test.go diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go new file mode 100644 index 0000000000..1a4825206f --- /dev/null +++ b/middleware/session/config_test.go @@ -0,0 +1,59 @@ +package session + +import ( + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" +) + +func TestConfigDefault(t *testing.T) { + // Test default config + cfg := configDefault() + assert.Equal(t, 24*time.Hour, cfg.IdleTimeout) + assert.Equal(t, "cookie:session_id", cfg.KeyLookup) + assert.NotNil(t, cfg.KeyGenerator) + assert.Equal(t, SourceCookie, cfg.source) + assert.Equal(t, "session_id", cfg.sessionName) +} + +func TestConfigDefaultWithCustomConfig(t *testing.T) { + // Test custom config + customConfig := Config{ + IdleTimeout: 48 * time.Hour, + KeyLookup: "header:custom_session_id", + KeyGenerator: func() string { return "custom_key" }, + } + cfg := configDefault(customConfig) + assert.Equal(t, 48*time.Hour, cfg.IdleTimeout) + assert.Equal(t, "header:custom_session_id", cfg.KeyLookup) + assert.NotNil(t, cfg.KeyGenerator) + assert.Equal(t, SourceHeader, cfg.source) + assert.Equal(t, "custom_session_id", cfg.sessionName) +} + +func TestDefaultErrorHandler(t *testing.T) { + // Create a new Fiber app + app := fiber.New() + + // Create a new context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + // Test DefaultErrorHandler + DefaultErrorHandler(&ctx, fiber.ErrInternalServerError) + assert.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) +} + +func TestInvalidKeyLookupFormat(t *testing.T) { + assert.PanicsWithValue(t, "[session] KeyLookup must in the form of :", func() { + configDefault(Config{KeyLookup: "invalid_format"}) + }) +} + +func TestUnsupportedSource(t *testing.T) { + assert.PanicsWithValue(t, "[session] source is not supported", func() { + configDefault(Config{KeyLookup: "unsupported:session_id"}) + }) +} From a865ba565823d0eb1edecc1006171d8c68c42bc1 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 14:36:26 -0300 Subject: [PATCH 32/79] test(middleware/session): Add unit tests for store.go --- middleware/session/store_test.go | 82 ++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 78a016841f..4886b0603c 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -116,3 +117,84 @@ func Test_Store_DeleteSession(t *testing.T) { // The session ID should be different now, because the old session was deleted require.NotEqual(t, sessionID, session.ID()) } + +func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { + // Create a new Fiber app + app := fiber.New() + + // Create a new context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + // Mock middleware and set it in the context + middleware := &Middleware{} + ctx.Locals(key, middleware) + + // Create a new store + store := &Store{} + + // Call the Get method + sess, err := store.Get(ctx) + + // Assert that the error is ErrSessionAlreadyLoadedByMiddleware + assert.Nil(t, sess) + assert.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err) +} + +func TestStore_Delete(t *testing.T) { + // Create a new store + store := newStore() + + t.Run("delete with empty session ID", func(t *testing.T) { + err := store.Delete("") + assert.Error(t, err) + assert.Equal(t, ErrEmptySessionID, err) + }) + + t.Run("delete non-existing session", func(t *testing.T) { + err := store.Delete("non-existing-session-id") + assert.NoError(t, err) + }) +} + +func Test_Store_GetSessionByID(t *testing.T) { + t.Parallel() + // Create a new store + store := newStore() + + t.Run("empty session ID", func(t *testing.T) { + t.Parallel() + sess, err := store.GetSessionByID("") + require.Error(t, err) + assert.Nil(t, sess) + assert.Equal(t, ErrEmptySessionID, err) + }) + + t.Run("non-existent session ID", func(t *testing.T) { + t.Parallel() + sess, err := store.GetSessionByID("non-existent-session-id") + require.Error(t, err) + assert.Nil(t, sess) + assert.Equal(t, ErrSessionIDNotFoundInStore, err) + }) + + t.Run("valid session ID", func(t *testing.T) { + t.Parallel() + // Create a new session + ctx := fiber.New().AcquireCtx(&fasthttp.RequestCtx{}) + session, err := store.Get(ctx) + require.NoError(t, err) + + // Save the session ID + sessionID := session.ID() + + // Save the session + err = session.Save() + require.NoError(t, err) + + // Retrieve the session by ID + retrievedSession, err := store.GetSessionByID(sessionID) + require.NoError(t, err) + assert.NotNil(t, retrievedSession) + assert.Equal(t, sessionID, retrievedSession.ID()) + }) +} From eaedc6dd7a6e4b41bb4b7b12d9d27a6dc0731bb9 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 14:44:55 -0300 Subject: [PATCH 33/79] test(middleware/session): Add data.go unit tests --- middleware/session/data_test.go | 178 ++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 middleware/session/data_test.go diff --git a/middleware/session/data_test.go b/middleware/session/data_test.go new file mode 100644 index 0000000000..3a94ed4ad9 --- /dev/null +++ b/middleware/session/data_test.go @@ -0,0 +1,178 @@ +package session + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeys(t *testing.T) { + t.Parallel() + + // Test case: Empty data + t.Run("Empty data", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + keys := d.Keys() + assert.Empty(t, keys, "Expected no keys in empty data") + }) + + // Test case: Single key + t.Run("Single key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + keys := d.Keys() + assert.Len(t, keys, 1, "Expected one key") + assert.Contains(t, keys, "key1", "Expected key1 to be present") + }) + + // Test case: Multiple keys + t.Run("Multiple keys", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + keys := d.Keys() + assert.Len(t, keys, 3, "Expected three keys") + assert.Contains(t, keys, "key1", "Expected key1 to be present") + assert.Contains(t, keys, "key2", "Expected key2 to be present") + assert.Contains(t, keys, "key3", "Expected key3 to be present") + }) + + // Test case: Concurrent access + t.Run("Concurrent access", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + + done := make(chan bool) + go func() { + keys := d.Keys() + assert.Len(t, keys, 3, "Expected three keys") + done <- true + }() + go func() { + keys := d.Keys() + assert.Len(t, keys, 3, "Expected three keys") + done <- true + }() + <-done + <-done + }) +} + +func TestData_Len(t *testing.T) { + t.Parallel() + + // Test case: Empty data + t.Run("Empty data", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + length := d.Len() + assert.Equal(t, 0, length, "Expected length to be 0 for empty data") + }) + + // Test case: Single key + t.Run("Single key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + length := d.Len() + assert.Equal(t, 1, length, "Expected length to be 1 when one key is set") + }) + + // Test case: Multiple keys + t.Run("Multiple keys", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + length := d.Len() + assert.Equal(t, 3, length, "Expected length to be 3 when three keys are set") + }) + + // Test case: Concurrent access + t.Run("Concurrent access", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Set("key3", "value3") + + done := make(chan bool) + go func() { + length := d.Len() + assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") + done <- true + }() + go func() { + length := d.Len() + assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") + done <- true + }() + <-done + <-done + }) +} + +func TestData_Get(t *testing.T) { + t.Parallel() + + // Test case: Non-existent key + t.Run("Non-existent key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + value := d.Get("non-existent-key") + assert.Nil(t, value, "Expected nil for non-existent key") + }) + + // Test case: Existing key + t.Run("Existing key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + value := d.Get("key1") + assert.Equal(t, "value1", value, "Expected value1 for key1") + }) +} + +func TestData_Reset(t *testing.T) { + t.Parallel() + + // Test case: Reset data + t.Run("Reset data", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Set("key2", "value2") + d.Reset() + assert.Empty(t, d.Data, "Expected data map to be empty after reset") + }) +} + +func TestData_Delete(t *testing.T) { + t.Parallel() + + // Test case: Delete existing key + t.Run("Delete existing key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Set("key1", "value1") + d.Delete("key1") + value := d.Get("key1") + assert.Nil(t, value, "Expected nil for deleted key") + }) + + // Test case: Delete non-existent key + t.Run("Delete non-existent key", func(t *testing.T) { + d := acquireData() + defer dataPool.Put(d) + d.Delete("non-existent-key") + // No assertion needed, just ensure no panic or error + }) +} From d2cf5b86358f52264a3884d5c830f16ad735f5d6 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 16:01:49 -0300 Subject: [PATCH 34/79] refactor(middleware/session): session tests and add session release test - Refactor session tests to improve readability and maintainability. - Add a new test case to ensure proper session release functionality. - Update session.md --- docs/middleware/session.md | 155 ++++++++++++----------------- middleware/session/config_test.go | 28 +++--- middleware/session/data_test.go | 54 +++++++--- middleware/session/session.go | 12 ++- middleware/session/session_test.go | 32 ++++++ middleware/session/store_test.go | 27 +++-- 6 files changed, 169 insertions(+), 139 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 6a02c78f15..54ee031d64 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -4,11 +4,9 @@ id: session # Session Middleware for [Fiber](https://github.com/gofiber/fiber) -The `session` middleware provides session handling for Fiber applications. It leverages the [Storage](https://github.com/gofiber/storage) package to offer support for multiple databases through a unified interface. By default, session data is stored in memory, but you can easily switch to other storage options, as shown in the examples below. +The `session` middleware provides session management for Fiber applications, utilizing the [Storage](https://github.com/gofiber/storage) package for multi-database support via a unified interface. By default, session data is stored in memory, but custom storage options are easily configurable (see examples below). -:::note -We recommend using the `Middleware` handler for better integration with other middleware. See the [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) section for details. -::: +As of v3, we recommend using the middleware handler for session management. However, for backward compatibility, v2's session methods are still available, allowing you to continue using the session management techniques from earlier versions of Fiber. Both methods are demonstrated in the examples. ## Table of Contents @@ -26,10 +24,10 @@ We recommend using the `Middleware` handler for better integration with other mi - [Session Methods](#session-methods) - [Store Methods](#store-methods) - [Examples](#examples) - - [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) - - [Using a Custom Storage](#using-a-custom-storage) + - [Middleware Handler (Recommended)](#middleware-handler-recommended) + - [Custom Storage Example](#custom-storage-example) - [Session Without Middleware Handler](#session-without-middleware-handler) - - [Using Custom Types in Session Data](#using-custom-types-in-session-data) + - [Custom Types in Session Data](#custom-types-in-session-data) - [Config](#config) - [Default Config](#default-config) @@ -37,13 +35,23 @@ We recommend using the `Middleware` handler for better integration with other mi ### v2 to v3 -- In version 3, the `New` function signature has been updated. It now returns a Fiber middleware handler instead of a `*Store`. To access the store, you can use the `Store` method on the `*Middleware` (obtained by calling `session.FromContext(c)` in a handler where the middleware is applied) or utilize the `NewWithStore` function. +- **Function Signature Change**: In v3, the `New` function now returns a middleware handler instead of a `*Store`. To access the store, use the `Store` method on `*Middleware` (obtained from `session.FromContext(c)` in a handler) or use `NewWithStore`. + +- **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly. + +For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). -While it's still possible to work with the `*Store` directly, we recommend using the `Middleware` handler for better integration with other Fiber middlewares. +### Migrating v2 to v3 Example (Legacy Approach) -For more information about changes in Fiber v3, see [What's New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). +To convert a v2 example to use the v3 legacy approach, follow these steps: -#### v2 Example +1. **Initialize with Store**: Use `session.NewWithStore()` to obtain both the middleware handler and store. +2. **Retrieve Session**: Access the session store using the `store.Get(c)` method. +3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle. + +#### Example Conversion + +**v2 Example:** ```go store := session.New() @@ -53,7 +61,7 @@ app.Get("/", func(c *fiber.Ctx) error { if err != nil { return err } - + key, ok := sess.Get("key").(string) if !ok { return c.SendStatus(fiber.StatusInternalServerError) @@ -70,7 +78,7 @@ app.Get("/", func(c *fiber.Ctx) error { }) ``` -#### v3 Example (Using Store) +**v3 Legacy Approach:** ```go _, store := session.NewWithStore() @@ -80,7 +88,8 @@ app.Get("/", func(c *fiber.Ctx) error { if err != nil { return err } - + defer sess.Release() // Important: Release the session + key, ok := sess.Get("key").(string) if !ok { return c.SendStatus(fiber.StatusInternalServerError) @@ -97,15 +106,15 @@ app.Get("/", func(c *fiber.Ctx) error { }) ``` -#### v3 Example (Using Middleware) +### v3 Example (Recommended Middleware Handler) -See the [As a Middleware Handler (Recommended)](#as-a-middleware-handler-recommended) section for details. +For the recommended approach, use the middleware handler. See the [Middleware Handler (Recommended)](#middleware-handler-recommended) section for details. ## Types ### Config -The configuration for the session middleware. +Defines the configuration options for the session middleware. ```go type Config struct { @@ -128,7 +137,7 @@ type Config struct { ### Middleware -The `Middleware` struct encapsulates the session middleware configuration and storage. It is created using the `New` or `NewWithStorage` function and used as a `fiber.Handler`. +The `Middleware` struct encapsulates the session middleware configuration and storage, created via `New` or `NewWithStore`. ```go type Middleware struct { @@ -138,7 +147,7 @@ type Middleware struct { ### Session -The `Session` struct is used to interact with session data. You can retrieve it from the `Middleware` using the `FromContext` method or from the `Store` using the `Get` method. +Represents a user session, accessible through `FromContext` or `Store.Get`. ```go type Session struct {} @@ -146,7 +155,7 @@ type Session struct {} ### Store -The `Store` struct is used to manage session data. It is created using the `NewWithStore` function or by calling the `Store` method on a `Middleware`. +Handles session data management and is created using `NewWithStore` or by accessing the `Store` method of a middleware instance. ```go type Store struct { @@ -172,8 +181,6 @@ func DefaultErrorHandler(c *fiber.Ctx, err error) ### Middleware Methods -Used to interact with session data when using the middleware handler. - ```go func (m *Middleware) Set(key string, value any) func (m *Middleware) Get(key string) any @@ -185,8 +192,6 @@ func (m *Middleware) Store() *Store ### Session Methods -If using the middleware handler, you generally won't need to use these methods directly. - ```go func (s *Session) Fresh() bool func (s *Session) ID() string @@ -194,6 +199,7 @@ func (s *Session) Get(key string) any func (s *Session) Set(key string, val any) func (s *Session) Destroy() error func (s *Session) Regenerate() error +func (s *Session) Release() func (s *Session) Reset() error func (s *Session) Save() error func (s *Session) Keys() []string @@ -213,34 +219,18 @@ func (s *Store) GetSessionByID(id string) (*Session, error) ## Examples :::note -**Security Note**: Fiber’s session middleware uses cookies with `SameSite=Lax` by default, which provides basic CSRF protection for most GET requests. However, for comprehensive security—especially for POST requests or sensitive operations (e.g., account changes, transactions, form submissions)—it is strongly recommended to use CSRF protection middleware. Fiber provides a `csrf` middleware that can be used in conjunction with the `session` middleware for robust protection. Find more information in the [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) documentation. - -### Recommendations - -1. **Session Middleware Without CSRF**: - - You can use the `session` middleware without the `csrf` middleware or rely solely on `SameSite=Lax` for basic protection in low-risk scenarios. - -2. **Double Submit Cookie Pattern**: - - You can implement the **double submit cookie pattern** (via custom, third-party, or built-in middleware), where the CSRF token is stored in a cookie, and the request includes the token in a hidden field or header. In this approach, there is no need to pass the `session.Store` to the `csrf` middleware. Simply apply the `session.New()` and `csrf.New()` middleware to the routes you want to protect. - -3. **Recommended Approach**: - - For stronger protection, especially in high-risk scenarios, use the `csrf` middleware with the session store. This method implements the **Synchronizer Token Pattern**, providing robust defense by associating the CSRF token with the user’s session. This approach requires passing the `session.Store` to the `csrf` middleware. - - Ensure the CSRF token is embedded in forms or included in a header for POST requests and verified on the server side for incoming requests. This adds a crucial security layer for state-changing actions. - +**Security Notice**: For robust security, especially during sensitive operations like account changes or transactions, consider using CSRF protection. Fiber provides a [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) that can be used with sessions to prevent CSRF attacks. ::: -### As a Middleware Handler (Recommended) +### Middleware Handler (Recommended) ```go package main import ( - "fmt" - "log" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/session/v3" - "github.com/gofiber/session/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" ) func main() { @@ -249,7 +239,6 @@ func main() { sessionMiddleware, sessionStore := session.NewWithStore() app.Use(sessionMiddleware) - app.Use(csrf.New(csrf.Config{ Store: sessionStore, })) @@ -265,27 +254,23 @@ func main() { return c.SendString("Welcome anonymous user!") } - return c.SendString(fmt.Sprintf("Welcome %v", name)) + return c.SendString("Welcome " + name) }) - log.Fatal(app.Listen(":3000")) + app.Listen(":3000") } ``` -### Using a Custom Storage - -This example shows how to use the `sqlite3` storage from the [Fiber storage package](https://github.com/gofiber/storage). +### Custom Storage Example ```go package main import ( - "log" - "github.com/gofiber/fiber/v3" "github.com/gofiber/storage/sqlite3" - "github.com/gofiber/session/v3" - "github.com/gofiber/session/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" ) func main() { @@ -297,28 +282,23 @@ func main() { }) app.Use(sessionMiddleware) - app.Use(csrf.New(csrf.Config{ Store: sessionStore, })) - log.Fatal(app.Listen(":3000")) + app.Listen(":3000") } ``` ### Session Without Middleware Handler -This example shows how to work with sessions directly without the middleware handler. - ```go package main import ( - "log" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/session/v3" - "github.com/gofiber/session/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/csrf" + "github.com/gofiber/fiber/v3/middleware/session" ) func main() { @@ -341,16 +321,14 @@ func main() { return c.SendString("Welcome anonymous user!") } - return c.SendString(fmt.Sprintf("Welcome %v", name)) + return c.SendString("Welcome " + name) }) - log.Fatal(app.Listen(":3000 - -")) + app.Listen(":3000") } ``` -### Using Custom Types in Session Data +### Custom Types in Session Data Session data can only be of the following types by default: @@ -379,11 +357,8 @@ To support other types in session data, you can register custom types. Here is a package main import ( - "log" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/session/v3" - "github.com/gofiber/session/v3/middleware/session" + "github.com/gofiber/fiber/v3/middleware/session" ) type User struct { @@ -392,40 +367,34 @@ type User struct { } func main() { - // Create a new Fiber app app := fiber.New() - // Initialize custom session config sessionMiddleware, sessionStore := session.NewWithStore() - - // Register custom type sessionStore.RegisterType(User{}) - // Use the session middleware app.Use(sessionMiddleware) - ... + app.Listen(":3000") } ``` ## Config -| Property | Type | Description | Default | -|:------------------------|:----------------|:--------------------------------------------------------------------------------------------------------------|:----------------------| -| Storage | `fiber.Storage` | Storage interface to store the session data. | `memory.New()` | -| Next | `func(c fiber.Ctx) bool` | Function to skip this middleware when returned true. | `nil` | -| Store | `*Store` | Defines the session store. | `nil` (Required) | -| ErrorHandler | `func(*fiber.Ctx, error)` | Function executed for errors. | `nil` | -| KeyGenerator | `func() string` | KeyGenerator generates the session key. | `utils.UUIDv4` | -| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract session id from the request. | `"cookie:session_id"` | -| CookieDomain | `string` | Domain of the cookie. | `""` | -| CookiePath | `string` | Path of the cookie. | `""` | -| CookieSameSite | `string` | Value of SameSite cookie. | `"Lax"` | -| IdleTimeout | `time.Duration` | Allowed session idle duration. | `24 * time.Hour` | -| Expiration | `time.Duration` | Allowed session duration. | `24 * time.Hour` | -| CookieSecure | `bool` | Indicates if cookie is secure. | `false` | -| CookieHTTPOnly | `bool` | Indicates if cookie is HTTP only. | `false` | -| CookieSessionOnly | `bool` | Decides whether cookie should last for only the browser session. Ignores Expiration if set to true. | `false` | +| Property | Type | Description | Default | +|-----------------------|--------------------------------|--------------------------------------------------------------------------------------------|---------------------------| +| **Storage** | `fiber.Storage` | Defines where session data is stored. | `nil` (in-memory storage) | +| **Next** | `func(c fiber.Ctx) bool` | Function to skip this middleware under certain conditions. | `nil` | +| **ErrorHandler** | `func(c fiber.Ctx, err error)` | Custom error handler for session middleware errors. | `nil` | +| **KeyGenerator** | `func() string` | Function to generate session IDs. | `UUID()` | +| **KeyLookup** | `string` | Key used to store session ID in cookie or header. | `"cookie:session_id"` | +| **CookieDomain** | `string` | The domain scope of the session cookie. | `""` | +| **CookiePath** | `string` | The path scope of the session cookie. | `"/"` | +| **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` | +| **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `0` (no idle timeout) | +| **Expiration** | `time.Duration` | Maximum session duration before expiration. | `24 * time.Hour` | +| **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` | +| **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` | +| **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` | ## Default Config diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go index 1a4825206f..171d5424f4 100644 --- a/middleware/session/config_test.go +++ b/middleware/session/config_test.go @@ -5,18 +5,18 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func TestConfigDefault(t *testing.T) { // Test default config cfg := configDefault() - assert.Equal(t, 24*time.Hour, cfg.IdleTimeout) - assert.Equal(t, "cookie:session_id", cfg.KeyLookup) - assert.NotNil(t, cfg.KeyGenerator) - assert.Equal(t, SourceCookie, cfg.source) - assert.Equal(t, "session_id", cfg.sessionName) + require.Equal(t, 24*time.Hour, cfg.IdleTimeout) + require.Equal(t, "cookie:session_id", cfg.KeyLookup) + require.NotNil(t, cfg.KeyGenerator) + require.Equal(t, SourceCookie, cfg.source) + require.Equal(t, "session_id", cfg.sessionName) } func TestConfigDefaultWithCustomConfig(t *testing.T) { @@ -27,11 +27,11 @@ func TestConfigDefaultWithCustomConfig(t *testing.T) { KeyGenerator: func() string { return "custom_key" }, } cfg := configDefault(customConfig) - assert.Equal(t, 48*time.Hour, cfg.IdleTimeout) - assert.Equal(t, "header:custom_session_id", cfg.KeyLookup) - assert.NotNil(t, cfg.KeyGenerator) - assert.Equal(t, SourceHeader, cfg.source) - assert.Equal(t, "custom_session_id", cfg.sessionName) + require.Equal(t, 48*time.Hour, cfg.IdleTimeout) + require.Equal(t, "header:custom_session_id", cfg.KeyLookup) + require.NotNil(t, cfg.KeyGenerator) + require.Equal(t, SourceHeader, cfg.source) + require.Equal(t, "custom_session_id", cfg.sessionName) } func TestDefaultErrorHandler(t *testing.T) { @@ -43,17 +43,17 @@ func TestDefaultErrorHandler(t *testing.T) { // Test DefaultErrorHandler DefaultErrorHandler(&ctx, fiber.ErrInternalServerError) - assert.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) + require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) } func TestInvalidKeyLookupFormat(t *testing.T) { - assert.PanicsWithValue(t, "[session] KeyLookup must in the form of :", func() { + require.PanicsWithValue(t, "[session] KeyLookup must in the form of :", func() { configDefault(Config{KeyLookup: "invalid_format"}) }) } func TestUnsupportedSource(t *testing.T) { - assert.PanicsWithValue(t, "[session] source is not supported", func() { + require.PanicsWithValue(t, "[session] source is not supported", func() { configDefault(Config{KeyLookup: "unsupported:session_id"}) }) } diff --git a/middleware/session/data_test.go b/middleware/session/data_test.go index 3a94ed4ad9..166e257e72 100644 --- a/middleware/session/data_test.go +++ b/middleware/session/data_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestKeys(t *testing.T) { @@ -11,40 +12,48 @@ func TestKeys(t *testing.T) { // Test case: Empty data t.Run("Empty data", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset keys := d.Keys() - assert.Empty(t, keys, "Expected no keys in empty data") + require.Empty(t, keys, "Expected no keys in empty data") }) // Test case: Single key t.Run("Single key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") keys := d.Keys() - assert.Len(t, keys, 1, "Expected one key") - assert.Contains(t, keys, "key1", "Expected key1 to be present") + require.Len(t, keys, 1, "Expected one key") + require.Contains(t, keys, "key1", "Expected key1 to be present") }) // Test case: Multiple keys t.Run("Multiple keys", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") keys := d.Keys() - assert.Len(t, keys, 3, "Expected three keys") - assert.Contains(t, keys, "key1", "Expected key1 to be present") - assert.Contains(t, keys, "key2", "Expected key2 to be present") - assert.Contains(t, keys, "key3", "Expected key3 to be present") + require.Len(t, keys, 3, "Expected three keys") + require.Contains(t, keys, "key1", "Expected key1 to be present") + require.Contains(t, keys, "key2", "Expected key2 to be present") + require.Contains(t, keys, "key3", "Expected key3 to be present") }) // Test case: Concurrent access t.Run("Concurrent access", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -70,36 +79,44 @@ func TestData_Len(t *testing.T) { // Test case: Empty data t.Run("Empty data", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset length := d.Len() - assert.Equal(t, 0, length, "Expected length to be 0 for empty data") + require.Equal(t, 0, length, "Expected length to be 0 for empty data") }) // Test case: Single key t.Run("Single key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") length := d.Len() - assert.Equal(t, 1, length, "Expected length to be 1 when one key is set") + require.Equal(t, 1, length, "Expected length to be 1 when one key is set") }) // Test case: Multiple keys t.Run("Multiple keys", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") length := d.Len() - assert.Equal(t, 3, length, "Expected length to be 3 when three keys are set") + require.Equal(t, 3, length, "Expected length to be 3 when three keys are set") }) // Test case: Concurrent access t.Run("Concurrent access", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -125,19 +142,23 @@ func TestData_Get(t *testing.T) { // Test case: Non-existent key t.Run("Non-existent key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset value := d.Get("non-existent-key") - assert.Nil(t, value, "Expected nil for non-existent key") + require.Nil(t, value, "Expected nil for non-existent key") }) // Test case: Existing key t.Run("Existing key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") value := d.Get("key1") - assert.Equal(t, "value1", value, "Expected value1 for key1") + require.Equal(t, "value1", value, "Expected value1 for key1") }) } @@ -146,12 +167,13 @@ func TestData_Reset(t *testing.T) { // Test case: Reset data t.Run("Reset data", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) d.Set("key1", "value1") d.Set("key2", "value2") d.Reset() - assert.Empty(t, d.Data, "Expected data map to be empty after reset") + require.Empty(t, d.Data, "Expected data map to be empty after reset") }) } @@ -160,18 +182,22 @@ func TestData_Delete(t *testing.T) { // Test case: Delete existing key t.Run("Delete existing key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Set("key1", "value1") d.Delete("key1") value := d.Get("key1") - assert.Nil(t, value, "Expected nil for deleted key") + require.Nil(t, value, "Expected nil for deleted key") }) // Test case: Delete non-existent key t.Run("Delete non-existent key", func(t *testing.T) { + t.Parallel() d := acquireData() defer dataPool.Put(d) + d.Reset() // Ensure data is reset d.Delete("non-existent-key") // No assertion needed, just ensure no panic or error }) diff --git a/middleware/session/session.go b/middleware/session/session.go index eae6db8cc2..ae5dbe3c43 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -58,9 +58,15 @@ func acquireSession() *Session { // // The session should not be used after calling this function. // +// Important: The Release function should only be used when accessing the session directly, +// for example, when you have called func (s *Session) Get(ctx) to get the session. +// It should not be used when using the session with a *Middleware handler in the request +// call stack, as the middleware will still need to access the session. +// // Usage: // -// s.Release() +// sess := session.Get(ctx) +// defer sess.Release() func (s *Session) Release() { if s == nil { return @@ -264,9 +270,7 @@ func (s *Session) refresh() { // Save updates the storage and client cookie. // // sess.Save() will save the session data to the storage and update the -// client cookie, and it will release the session after saving. -// -// It's not safe to use the session after calling Save(). +// client cookie. // // Returns: // - error: An error if the save operation fails. diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 53516c96ce..e6cd98f765 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -12,6 +12,8 @@ import ( "github.com/valyala/fasthttp" ) +const testSessionID = "test-session-id" + // go test -run Test_Session func Test_Session(t *testing.T) { t.Parallel() @@ -963,3 +965,33 @@ func Test_Session_Concurrency(t *testing.T) { require.NoError(t, err) } } + +// go test -v race -run Test_Session_Release -count 4 +func Test_Session_Release(t *testing.T) { + t.Parallel() + + // session store + store := newStore() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // acquire a new session + sess := acquireSession() + sess.ctx = ctx + sess.config = store + sess.id = testSessionID + sess.Set("key", "value") + + // release the session + sess.Release() + + // assertions + require.Empty(t, sess.id) + require.Nil(t, sess.ctx) + require.Nil(t, sess.config) + require.Empty(t, sess.Keys()) + require.Zero(t, sess.byteBuffer.Len()) +} diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 4886b0603c..6f2e611ae8 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/gofiber/fiber/v3" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -13,7 +12,7 @@ import ( // go test -run Test_Store_getSessionID func Test_Store_getSessionID(t *testing.T) { t.Parallel() - expectedID := "test-session-id" + expectedID := testSessionID // fiber instance app := fiber.New() @@ -68,7 +67,7 @@ func Test_Store_getSessionID(t *testing.T) { func Test_Store_Get(t *testing.T) { // Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v t.Parallel() - unexpectedID := "test-session-id" + unexpectedID := testSessionID // fiber instance app := fiber.New() t.Run("session should be re-generated if it is invalid", func(t *testing.T) { @@ -136,8 +135,8 @@ func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { sess, err := store.Get(ctx) // Assert that the error is ErrSessionAlreadyLoadedByMiddleware - assert.Nil(t, sess) - assert.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err) + require.Nil(t, sess) + require.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err) } func TestStore_Delete(t *testing.T) { @@ -146,13 +145,13 @@ func TestStore_Delete(t *testing.T) { t.Run("delete with empty session ID", func(t *testing.T) { err := store.Delete("") - assert.Error(t, err) - assert.Equal(t, ErrEmptySessionID, err) + require.Error(t, err) + require.Equal(t, ErrEmptySessionID, err) }) t.Run("delete non-existing session", func(t *testing.T) { err := store.Delete("non-existing-session-id") - assert.NoError(t, err) + require.NoError(t, err) }) } @@ -165,16 +164,16 @@ func Test_Store_GetSessionByID(t *testing.T) { t.Parallel() sess, err := store.GetSessionByID("") require.Error(t, err) - assert.Nil(t, sess) - assert.Equal(t, ErrEmptySessionID, err) + require.Nil(t, sess) + require.Equal(t, ErrEmptySessionID, err) }) t.Run("non-existent session ID", func(t *testing.T) { t.Parallel() sess, err := store.GetSessionByID("non-existent-session-id") require.Error(t, err) - assert.Nil(t, sess) - assert.Equal(t, ErrSessionIDNotFoundInStore, err) + require.Nil(t, sess) + require.Equal(t, ErrSessionIDNotFoundInStore, err) }) t.Run("valid session ID", func(t *testing.T) { @@ -194,7 +193,7 @@ func Test_Store_GetSessionByID(t *testing.T) { // Retrieve the session by ID retrievedSession, err := store.GetSessionByID(sessionID) require.NoError(t, err) - assert.NotNil(t, retrievedSession) - assert.Equal(t, sessionID, retrievedSession.ID()) + require.NotNil(t, retrievedSession) + require.Equal(t, sessionID, retrievedSession.ID()) }) } From b479895da2de7aa4f6dec5e1aeb2bf46bbd6a2ff Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 16:32:26 -0300 Subject: [PATCH 35/79] refactor: session data locking in middleware/session/data.go --- middleware/session/data.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/middleware/session/data.go b/middleware/session/data.go index 81732babcc..7d278d8e21 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -99,12 +99,12 @@ func (d *data) Delete(key string) { // // keys := d.Keys() func (d *data) Keys() []string { - d.Lock() + d.RLock() keys := make([]string, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } - d.Unlock() + d.RUnlock() return keys } @@ -117,5 +117,7 @@ func (d *data) Keys() []string { // // length := d.Len() func (d *data) Len() int { + d.RLock() + defer d.RUnlock() return len(d.Data) } From afab5806f705280bc15ca680074ecda49e4c71a2 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 16:43:35 -0300 Subject: [PATCH 36/79] refactor(middleware/session): Add unit test for session middleware store --- middleware/session/middleware_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 6855ae9d8d..579d61c44c 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -441,3 +441,29 @@ func Test_Session_Next(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode()) } + +func Test_Session_Middleware_Store(t *testing.T) { + t.Parallel() + app := fiber.New() + + handler, sessionStore := NewWithStore() + + app.Use(handler) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + st := sess.Store() + if st != sessionStore { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + // Test GET request + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) +} From 6c0bf253ddbf8dd3620e02dcc7abcec25af269e4 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 16:55:40 -0300 Subject: [PATCH 37/79] test: fix session_test.go and store_test.go unit tests --- middleware/session/session_test.go | 7 +++---- middleware/session/store_test.go | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index e6cd98f765..c27a535fb7 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -8,12 +8,11 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) -const testSessionID = "test-session-id" - // go test -run Test_Session func Test_Session(t *testing.T) { t.Parallel() @@ -982,8 +981,8 @@ func Test_Session_Release(t *testing.T) { sess := acquireSession() sess.ctx = ctx sess.config = store - sess.id = testSessionID - sess.Set("key", "value") + rid, _ := uuid.NewRandom() + sess.id = rid.String() // release the session sess.Release() diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 6f2e611ae8..06a744d138 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -12,7 +12,7 @@ import ( // go test -run Test_Store_getSessionID func Test_Store_getSessionID(t *testing.T) { t.Parallel() - expectedID := testSessionID + expectedID := "test-session-id" // fiber instance app := fiber.New() @@ -67,7 +67,7 @@ func Test_Store_getSessionID(t *testing.T) { func Test_Store_Get(t *testing.T) { // Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v t.Parallel() - unexpectedID := testSessionID + unexpectedID := "test-session-id" // fiber instance app := fiber.New() t.Run("session should be re-generated if it is invalid", func(t *testing.T) { From ad337f836fd113615896d9ded9de1039de51a7df Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 17:09:48 -0300 Subject: [PATCH 38/79] refactor(docs): Update session.md with v3 changes to Expiration --- docs/middleware/session.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 54ee031d64..86a3ec8824 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -39,6 +39,8 @@ As of v3, we recommend using the middleware handler for session management. Howe - **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly. +- **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with the `IdleTimeout` field, which explicitly defines the session's idle timeout period. Users who need to set a maximum session duration must now implement this logic themselves using data stored in the session. + For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). ### Migrating v2 to v3 Example (Legacy Approach) @@ -108,6 +110,8 @@ app.Get("/", func(c *fiber.Ctx) error { ### v3 Example (Recommended Middleware Handler) +Do not call `sess.Release()` when using the middleware handler. `sess.Save()` is also not required, as the middleware automatically saves the session data. + For the recommended approach, use the middleware handler. See the [Middleware Handler (Recommended)](#middleware-handler-recommended) section for details. ## Types From 280d5399dc1b22938557bf7d27b73df13630fb7c Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 21:00:23 -0300 Subject: [PATCH 39/79] refactor(middleware/session): Improve data pool handling and locking --- middleware/session/data.go | 20 ++++--- middleware/session/session_test.go | 85 +++++++++++++++++++----------- 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/middleware/session/data.go b/middleware/session/data.go index 7d278d8e21..93f7c06f57 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -29,7 +29,12 @@ var dataPool = sync.Pool{ // // d := acquireData() func acquireData() *data { - return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool + obj := dataPool.Get() + if d, ok := obj.(*data); ok { + return d + } + // Handle unexpected type in the pool + panic("unexpected type in data pool") } // Reset clears the data map and resets the data object. @@ -39,8 +44,8 @@ func acquireData() *data { // d.Reset() func (d *data) Reset() { d.Lock() + defer d.Unlock() d.Data = make(map[string]any) - d.Unlock() } // Get retrieves a value from the data map by key. @@ -56,9 +61,8 @@ func (d *data) Reset() { // value := d.Get("key") func (d *data) Get(key string) any { d.RLock() - v := d.Data[key] - d.RUnlock() - return v + defer d.RUnlock() + return d.Data[key] } // Set updates or creates a new key-value pair in the data map. @@ -72,8 +76,8 @@ func (d *data) Get(key string) any { // d.Set("key", "value") func (d *data) Set(key string, value any) { d.Lock() + defer d.Unlock() d.Data[key] = value - d.Unlock() } // Delete removes a key-value pair from the data map. @@ -86,8 +90,8 @@ func (d *data) Set(key string, value any) { // d.Delete("key") func (d *data) Delete(key string) { d.Lock() + defer d.Unlock() delete(d.Data, key) - d.Unlock() } // Keys retrieves all keys in the data map. @@ -100,11 +104,11 @@ func (d *data) Delete(key string) { // keys := d.Keys() func (d *data) Keys() []string { d.RLock() + defer d.RUnlock() keys := make([]string, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } - d.RUnlock() return keys } diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index c27a535fb7..ef8d469ade 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -8,7 +8,6 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -25,7 +24,6 @@ func Test_Session(t *testing.T) { // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // Get a new session sess, err := store.Get(ctx) @@ -34,6 +32,7 @@ func Test_Session(t *testing.T) { token := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -82,6 +81,9 @@ func Test_Session(t *testing.T) { err = sess.Save() require.NoError(t, err) + // release the session + sess.Release() + // release the context app.ReleaseCtx(ctx) // requesting entirely new context to prevent falsy tests @@ -94,6 +96,8 @@ func Test_Session(t *testing.T) { // this id should be randomly generated as session key was deleted require.Len(t, sess.ID(), 36) + sess.Release() + // when we use the original session for the second time // the session be should be same if the session is not expired app.ReleaseCtx(ctx) @@ -103,6 +107,7 @@ func Test_Session(t *testing.T) { // request the server with the old session ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Equal(t, sess.id, id) @@ -187,6 +192,7 @@ func Test_Session_Types(t *testing.T) { err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -278,6 +284,8 @@ func Test_Session_Types(t *testing.T) { require.True(t, ok) require.Equal(t, vcomplex128, vcomplex128Result) + sess.Release() + app.ReleaseCtx(ctx) } @@ -305,6 +313,7 @@ func Test_Session_Store_Reset(t *testing.T) { require.NoError(t, store.Reset()) id := sess.ID() + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -312,6 +321,7 @@ func Test_Session_Store_Reset(t *testing.T) { // make sure the session is recreated sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.True(t, sess.Fresh()) require.Nil(t, sess.Get("hello")) @@ -339,6 +349,7 @@ func Test_Session_Save(t *testing.T) { // save session err = sess.Save() require.NoError(t, err) + sess.Release() }) t.Run("save to header", func(t *testing.T) { @@ -364,6 +375,7 @@ func Test_Session_Save(t *testing.T) { require.NoError(t, err) require.Equal(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName))) + sess.Release() }) } @@ -398,6 +410,7 @@ func Test_Session_Save_Expiration(t *testing.T) { err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -410,6 +423,8 @@ func Test_Session_Save_Expiration(t *testing.T) { // just to make sure the session has been expired time.Sleep(sessionDuration + (10 * time.Millisecond)) + sess.Release() + app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -417,6 +432,7 @@ func Test_Session_Save_Expiration(t *testing.T) { // here you should get a new session ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) @@ -439,6 +455,7 @@ func Test_Session_Destroy(t *testing.T) { // get session sess, err := store.Get(ctx) + defer sess.Release() require.NoError(t, err) sess.Set("name", "fenny") @@ -468,6 +485,7 @@ func Test_Session_Destroy(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -476,6 +494,7 @@ func Test_Session_Destroy(t *testing.T) { ctx.Request().Header.Set(store.sessionName, id) sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() err = sess.Destroy() require.NoError(t, err) @@ -512,6 +531,8 @@ func Test_Session_Cookie(t *testing.T) { require.NoError(t, err) require.NoError(t, sess.Save()) + sess.Release() + // cookie should be set on Save ( even if empty data ) require.Len(t, ctx.Response().Header.PeekCookie(store.sessionName), 84) } @@ -535,8 +556,11 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() + sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() sess.Set("name", "john") require.True(t, sess.Fresh()) require.Equal(t, id, sess.ID()) // session id should be the same @@ -560,6 +584,7 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Set("id", "1") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) @@ -569,11 +594,13 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Delete("id") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Nil(t, sess.Get("id")) @@ -610,6 +637,7 @@ func Test_Session_Reset(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -642,6 +670,8 @@ func Test_Session_Reset(t *testing.T) { err = acquiredSession.Save() require.NoError(t, err) + acquiredSession.Release() + // Check that the session id is not in the header or cookie anymore require.Equal(t, "", string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, "", string(ctx.Request().Header.Peek(store.sessionName))) @@ -675,6 +705,8 @@ func Test_Session_Regenerate(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() + // release the context app.ReleaseCtx(ctx) @@ -687,6 +719,7 @@ func Test_Session_Regenerate(t *testing.T) { // as the session is in the storage, session.fresh should be false acquiredSession, err := store.Get(ctx) require.NoError(t, err) + defer acquiredSession.Release() require.False(t, acquiredSession.Fresh()) err = acquiredSession.Regenerate() @@ -716,6 +749,8 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) @@ -734,6 +769,8 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) } @@ -752,6 +789,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -772,6 +812,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -794,6 +837,7 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) @@ -814,6 +858,7 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) } @@ -833,6 +878,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -854,6 +900,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -902,6 +949,9 @@ func Test_Session_Concurrency(t *testing.T) { return } + // release the session + sess.Release() + // Release the context app.ReleaseCtx(localCtx) @@ -918,6 +968,7 @@ func Test_Session_Concurrency(t *testing.T) { errChan <- err return } + defer sess.Release() // Get the value name := sess.Get("name") @@ -964,33 +1015,3 @@ func Test_Session_Concurrency(t *testing.T) { require.NoError(t, err) } } - -// go test -v race -run Test_Session_Release -count 4 -func Test_Session_Release(t *testing.T) { - t.Parallel() - - // session store - store := newStore() - // fiber instance - app := fiber.New() - // fiber context - ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) - - // acquire a new session - sess := acquireSession() - sess.ctx = ctx - sess.config = store - rid, _ := uuid.NewRandom() - sess.id = rid.String() - - // release the session - sess.Release() - - // assertions - require.Empty(t, sess.id) - require.Nil(t, sess.ctx) - require.Nil(t, sess.config) - require.Empty(t, sess.Keys()) - require.Zero(t, sess.byteBuffer.Len()) -} From 40da2c04f0b067f7be77b3e7e022a94b7c30a370 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 21:32:46 -0300 Subject: [PATCH 40/79] chore(middleware/session): TODO for Expiration field in session config --- middleware/session/config.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index ed3e557b3b..9194408bc1 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -62,9 +62,10 @@ type Config struct { // Optional. Default value 24 * time.Hour IdleTimeout time.Duration - // Allowed session duration - // Optional. Default value 24 * time.Hour - Expiration time.Duration + // TODO: Implement this, or remove and leave it to the user to implement + // // Allowed session duration + // // Optional. Default value 24 * time.Hour + // Expiration time.Duration // Indicates if cookie is secure. // Optional. Default value false. From 3ad4bc938922ab815a5c33818defe89ecf2d919b Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 21:50:18 -0300 Subject: [PATCH 41/79] refactor(middleware/session): Improve session data pool handling and locking --- middleware/session/session.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index ae5dbe3c43..1c5b22f67a 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -40,6 +40,8 @@ var sessionPool = sync.Pool{ // s := acquireSession() func acquireSession() *Session { s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + s.mu.Lock() + defer s.mu.Unlock() if s.data == nil { s.data = acquireData() } @@ -76,6 +78,7 @@ func (s *Session) Release() { func releaseSession(s *Session) { s.mu.Lock() + defer s.mu.Unlock() s.id = "" s.idleTimeout = 0 s.ctx = nil @@ -86,7 +89,6 @@ func releaseSession(s *Session) { if s.byteBuffer != nil { s.byteBuffer.Reset() } - s.mu.Unlock() sessionPool.Put(s) } @@ -295,6 +297,7 @@ func (s *Session) saveSession() error { } s.mu.Lock() + defer s.mu.Unlock() // Check if session has your own expiration, otherwise use default value if s.idleTimeout <= 0 { @@ -316,13 +319,7 @@ func (s *Session) saveSession() error { copy(encodedBytes, s.byteBuffer.Bytes()) // Pass copied bytes with session id to provider - if err := s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout); err != nil { - return err - } - - s.mu.Unlock() - - return nil + return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout) } // Keys retrieves all keys in the current session. From ffac824f50c043ba47a899d7e540eb9ed7f5a598 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 22:14:17 -0300 Subject: [PATCH 42/79] refactor(middleware/session): Improve session data pool handling and locking --- middleware/session/data_test.go | 24 ++++++++++++------------ middleware/session/session.go | 2 -- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/middleware/session/data_test.go b/middleware/session/data_test.go index 166e257e72..7b8b0787b5 100644 --- a/middleware/session/data_test.go +++ b/middleware/session/data_test.go @@ -15,7 +15,7 @@ func TestKeys(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() keys := d.Keys() require.Empty(t, keys, "Expected no keys in empty data") }) @@ -25,7 +25,7 @@ func TestKeys(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") keys := d.Keys() require.Len(t, keys, 1, "Expected one key") @@ -37,7 +37,7 @@ func TestKeys(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -53,7 +53,7 @@ func TestKeys(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -82,7 +82,7 @@ func TestData_Len(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() length := d.Len() require.Equal(t, 0, length, "Expected length to be 0 for empty data") }) @@ -92,7 +92,7 @@ func TestData_Len(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") length := d.Len() require.Equal(t, 1, length, "Expected length to be 1 when one key is set") @@ -103,7 +103,7 @@ func TestData_Len(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -116,7 +116,7 @@ func TestData_Len(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") @@ -145,7 +145,7 @@ func TestData_Get(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() value := d.Get("non-existent-key") require.Nil(t, value, "Expected nil for non-existent key") }) @@ -155,7 +155,7 @@ func TestData_Get(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") value := d.Get("key1") require.Equal(t, "value1", value, "Expected value1 for key1") @@ -185,7 +185,7 @@ func TestData_Delete(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Set("key1", "value1") d.Delete("key1") value := d.Get("key1") @@ -197,7 +197,7 @@ func TestData_Delete(t *testing.T) { t.Parallel() d := acquireData() defer dataPool.Put(d) - d.Reset() // Ensure data is reset + defer d.Reset() d.Delete("non-existent-key") // No assertion needed, just ensure no panic or error }) diff --git a/middleware/session/session.go b/middleware/session/session.go index 1c5b22f67a..7bff93d461 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -40,8 +40,6 @@ var sessionPool = sync.Pool{ // s := acquireSession() func acquireSession() *Session { s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool - s.mu.Lock() - defer s.mu.Unlock() if s.data == nil { s.data = acquireData() } From 9f8c2d714d63c1e23950d32e1570f00e413d46ae Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 19 Sep 2024 13:07:49 -0300 Subject: [PATCH 43/79] test(middleware/csrf): add session middleware coverage --- middleware/csrf/csrf_test.go | 55 +++++++++++++++++++++++++++++++++++ middleware/session/session.go | 7 +---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 77ecb12e9b..133752abba 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -156,6 +156,61 @@ func Test_CSRF_WithSession(t *testing.T) { } } +// go test -run Test_CSRF_WithSession_Middleware +func Test_CSRF_WithSession_Middleware(t *testing.T) { + t.Parallel() + app := fiber.New() + + // session mw + smh, sstore := session.NewWithStore() + + // csrf mw + cmh := New(Config{ + Session: sstore, + }) + + app.Use(smh) + + app.Use(cmh) + + app.Get("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + sess.Set("hello", "world") + return c.SendStatus(fiber.StatusOK) + }) + + app.Post("/", func(c fiber.Ctx) error { + sess := session.FromContext(c) + if sess.Get("hello") != "world" { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + // Generate CSRF token and session_id + ctx.Request.Header.SetMethod(fiber.MethodGet) + h(ctx) + csrfTokenParts := strings.Split(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), ";") + require.Greater(t, len(csrfTokenParts), 2) + csrfToken := strings.Split(csrfTokenParts[0], "=")[1] + require.NotEmpty(t, csrfToken) + sessionID := strings.Split(csrfTokenParts[1], "=")[1] + require.NotEmpty(t, sessionID) + + // Use the CSRF token and session_id + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.Header.Set(HeaderName, csrfToken) + ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken) + ctx.Request.Header.SetCookie("session_id", sessionID) + h(ctx) + require.Equal(t, 200, ctx.Response.StatusCode()) +} + // go test -run Test_CSRF_ExpiredToken func Test_CSRF_ExpiredToken(t *testing.T) { t.Parallel() diff --git a/middleware/session/session.go b/middleware/session/session.go index 7bff93d461..07043b1aa5 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -76,7 +76,6 @@ func (s *Session) Release() { func releaseSession(s *Session) { s.mu.Lock() - defer s.mu.Unlock() s.id = "" s.idleTimeout = 0 s.ctx = nil @@ -87,6 +86,7 @@ func releaseSession(s *Session) { if s.byteBuffer != nil { s.byteBuffer.Reset() } + s.mu.Unlock() sessionPool.Put(s) } @@ -130,7 +130,6 @@ func (s *Session) ID() string { // // value := s.Get("key") func (s *Session) Get(key string) any { - // Better safe than sorry if s.data == nil { return nil } @@ -147,7 +146,6 @@ func (s *Session) Get(key string) any { // // s.Set("key", "value") func (s *Session) Set(key string, val any) { - // Better safe than sorry if s.data == nil { return } @@ -163,7 +161,6 @@ func (s *Session) Set(key string, val any) { // // s.Delete("key") func (s *Session) Delete(key string) { - // Better safe than sorry if s.data == nil { return } @@ -179,7 +176,6 @@ func (s *Session) Delete(key string) { // // err := s.Destroy() func (s *Session) Destroy() error { - // Better safe than sorry if s.data == nil { return nil } @@ -289,7 +285,6 @@ func (s *Session) Save() error { } func (s *Session) saveSession() error { - // Better safe than sorry if s.data == nil { return nil } From ecac9ce65c8ce2e2962afe82a9416f1f7abf9538 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 19 Sep 2024 13:14:27 -0300 Subject: [PATCH 44/79] chroe(middleware/session): TODO for unregistered session middleware --- middleware/session/middleware.go | 1 + 1 file changed, 1 insertion(+) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 2da0f429f4..77597ba776 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -180,6 +180,7 @@ func releaseMiddleware(m *Middleware) { func FromContext(c fiber.Ctx) *Middleware { m, ok := c.Locals(key).(*Middleware) if !ok { + // TODO: since this may be called we may not want to log this except in debug mode? log.Warn("session: Session middleware not registered. See https://docs.gofiber.io/middleware/session") return nil } From e27208256ee596233afdbae2786a6c8da9777e4c Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 19 Sep 2024 13:35:58 -0300 Subject: [PATCH 45/79] refactor(middleware/session): Update session middleware for v3 changes --- docs/whats_new.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/whats_new.md b/docs/whats_new.md index 963d1daece..b51288b4ab 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -314,9 +314,17 @@ Added support for specifying Key length when using `encryptcookie.GenerateKey(le ### Session -:::caution -DRAFT section -::: +The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. + +#### Key Updates: + +- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewWithStore` for custom store integration. + +- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. + +- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which strictly handles session inactivity. If you require a maximum session duration, you'll need to implement it within your own session data. + +For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). ### Filesystem From b262a082a69b6788a6a5f9ec21729eaf1f1cd830 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 19 Sep 2024 13:37:19 -0300 Subject: [PATCH 46/79] refactor(middleware/session): Update session middleware for v3 changes --- docs/whats_new.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/whats_new.md b/docs/whats_new.md index b51288b4ab..4cb6b20b18 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -314,9 +314,9 @@ Added support for specifying Key length when using `encryptcookie.GenerateKey(le ### Session -The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. +The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. -#### Key Updates: +#### Key Updates - **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewWithStore` for custom store integration. From 9ec2b30312c9fb4f65b876d4906936461d789296 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 13:56:32 -0300 Subject: [PATCH 47/79] refactor(middleware/session): Update session middleware idle timeout - Update the default idle timeout for session middleware from 24 hours to 30 minutes. - Add a note in the session middleware documentation about the importance of the middleware order. --- docs/middleware/session.md | 9 ++++++--- middleware/session/config.go | 11 +++-------- middleware/session/middleware.go | 15 +++++---------- middleware/session/session.go | 4 ++-- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 86a3ec8824..4e38426907 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -226,6 +226,10 @@ func (s *Store) GetSessionByID(id string) (*Session, error) **Security Notice**: For robust security, especially during sensitive operations like account changes or transactions, consider using CSRF protection. Fiber provides a [CSRF Middleware](https://docs.gofiber.io/api/middleware/csrf) that can be used with sessions to prevent CSRF attacks. ::: +:::note +**Middleware Order**: The order of middleware matters. The session middleware should come before any handler or middleware that uses the session (for example, the CSRF middleware). +::: + ### Middleware Handler (Recommended) ```go @@ -395,7 +399,7 @@ func main() { | **CookiePath** | `string` | The path scope of the session cookie. | `"/"` | | **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` | | **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `0` (no idle timeout) | -| **Expiration** | `time.Duration` | Maximum session duration before expiration. | `24 * time.Hour` | +| **Expiration** | `time.Duration` | Maximum session duration before expiration. | `30 * time.Minute` | | **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` | | **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` | | **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` | @@ -413,8 +417,7 @@ session.Config{ CookieDomain: "", CookiePath: "", CookieSameSite: "Lax", - IdleTimeout: 24 * time.Hour, - Expiration: 24 * time.Hour, + IdleTimeout: 30 * time.Minute, CookieSecure: false, CookieHTTPOnly: false, CookieSessionOnly: false, diff --git a/middleware/session/config.go b/middleware/session/config.go index 9194408bc1..b6dae196f8 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -59,14 +59,9 @@ type Config struct { sessionName string // Allowed session idle duration - // Optional. Default value 24 * time.Hour + // Optional. Default value 30 * time.Minute. IdleTimeout time.Duration - // TODO: Implement this, or remove and leave it to the user to implement - // // Allowed session duration - // // Optional. Default value 24 * time.Hour - // Expiration time.Duration - // Indicates if cookie is secure. // Optional. Default value false. CookieSecure bool @@ -76,7 +71,7 @@ type Config struct { CookieHTTPOnly bool // Decides whether cookie should last for only the browser session. - // Ignores Expiration if set to true + // Ignores IdleTimeout if set to true // Optional. Default value false. CookieSessionOnly bool } @@ -91,7 +86,7 @@ const ( // ConfigDefault is the default config var ConfigDefault = Config{ - IdleTimeout: 24 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyLookup: "cookie:session_id", KeyGenerator: utils.UUIDv4, source: "cookie", diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 77597ba776..51ca5d32c9 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -10,12 +10,11 @@ import ( // Middleware defines the session middleware configuration type Middleware struct { - Session *Session - ctx *fiber.Ctx - config Config - mu sync.RWMutex - hasChanged bool // TODO: use this to optimize interaction with the session store - destroyed bool + Session *Session + ctx *fiber.Ctx + config Config + mu sync.RWMutex + destroyed bool } // key for looking up session middleware in request context @@ -161,7 +160,6 @@ func releaseMiddleware(m *Middleware) { m.Session = nil m.ctx = nil m.destroyed = false - m.hasChanged = false m.mu.Unlock() middlewarePool.Put(m) } @@ -201,7 +199,6 @@ func (m *Middleware) Set(key string, value any) { defer m.mu.Unlock() m.Session.Set(key, value) - m.hasChanged = true } // Get retrieves a value from the session by key. @@ -235,7 +232,6 @@ func (m *Middleware) Delete(key string) { defer m.mu.Unlock() m.Session.Delete(key) - m.hasChanged = true } // Destroy destroys the session. @@ -292,7 +288,6 @@ func (m *Middleware) Reset() error { defer m.mu.Unlock() err := m.Session.Reset() - m.hasChanged = true return err } diff --git a/middleware/session/session.go b/middleware/session/session.go index 07043b1aa5..1f739209fe 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -292,7 +292,7 @@ func (s *Session) saveSession() error { s.mu.Lock() defer s.mu.Unlock() - // Check if session has your own expiration, otherwise use default value + // Check is the session has an idle timeout if s.idleTimeout <= 0 { s.idleTimeout = s.config.IdleTimeout } @@ -330,7 +330,7 @@ func (s *Session) Keys() []string { return s.data.Keys() } -// SetIdleTimeout sets a specific expiration for this session. +// SetIdleTimeout sets a specific idle timeout for the session. // // Parameters: // - idleTimeout: The duration for the idle timeout. From 684dc8a727c7de2568f0914dc8b57587e623c203 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 14:02:12 -0300 Subject: [PATCH 48/79] docws(middleware/session): Add note about IdleTimeout requiring save using legacy approach --- docs/middleware/session.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 4e38426907..b5ce61b77a 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -51,6 +51,10 @@ To convert a v2 example to use the v3 legacy approach, follow these steps: 2. **Retrieve Session**: Access the session store using the `store.Get(c)` method. 3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle. +:::note +When using the legacy approach, the IdleTimeout will only be updated when the session is saved. +::: + #### Example Conversion **v2 Example:** From 05d30a4cc0b47de8f62923bb586d6c50cea5f9ae Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 14:52:11 -0300 Subject: [PATCH 49/79] refactor(middleware/session): Update session middleware idle timeout Update the idle timeout for the session middleware to 30 minutes. This ensures that the session expires after a period of inactivity. The previous value was 24 hours, which is too long for most use cases. This change improves the security and efficiency of the session management. --- docs/middleware/session.md | 3 +-- middleware/session/config_test.go | 2 +- middleware/session/session.go | 10 +++++++--- middleware/session/session_test.go | 4 +++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index b5ce61b77a..96cdb47169 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -402,8 +402,7 @@ func main() { | **CookieDomain** | `string` | The domain scope of the session cookie. | `""` | | **CookiePath** | `string` | The path scope of the session cookie. | `"/"` | | **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` | -| **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `0` (no idle timeout) | -| **Expiration** | `time.Duration` | Maximum session duration before expiration. | `30 * time.Minute` | +| **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `30 * time.Minute` | | **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` | | **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` | | **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` | diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go index 171d5424f4..45456f6ea0 100644 --- a/middleware/session/config_test.go +++ b/middleware/session/config_test.go @@ -12,7 +12,7 @@ import ( func TestConfigDefault(t *testing.T) { // Test default config cfg := configDefault() - require.Equal(t, 24*time.Hour, cfg.IdleTimeout) + require.Equal(t, 30*time.Minute, cfg.IdleTimeout) require.Equal(t, "cookie:session_id", cfg.KeyLookup) require.NotNil(t, cfg.KeyGenerator) require.Equal(t, SourceCookie, cfg.source) diff --git a/middleware/session/session.go b/middleware/session/session.go index 1f739209fe..7cb7a59d9f 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -268,6 +268,8 @@ func (s *Session) refresh() { // sess.Save() will save the session data to the storage and update the // client cookie. // +// Checks if the session is being used in the handler, if so, it will not save the session. +// // Returns: // - error: An error if the save operation fails. // @@ -276,9 +278,11 @@ func (s *Session) refresh() { // err := s.Save() func (s *Session) Save() error { // If the session is being used in the handler, it should not be saved - if _, ok := s.ctx.Locals(key).(*Middleware); ok { - // Session is in use, so we do nothing and return - return nil + if m, ok := s.ctx.Locals(key).(*Middleware); ok { + if m.Session == s { + // Session is in use, so we do nothing and return + return nil + } } return s.saveSession() diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index ef8d469ade..5a628d4723 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -534,7 +534,9 @@ func Test_Session_Cookie(t *testing.T) { sess.Release() // cookie should be set on Save ( even if empty data ) - require.Len(t, ctx.Response().Header.PeekCookie(store.sessionName), 84) + cookie := ctx.Response().Header.PeekCookie(store.sessionName) + require.NotNil(t, cookie) + require.Regexp(t, `^session_id=[a-f0-9\-]{36}; max-age=\d+; path=/; SameSite=Lax$`, string(cookie)) } // go test -run Test_Session_Cookie_In_Response From ec5a698b076df1a7cc549d20e4d229671dbb8dd9 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 15:09:04 -0300 Subject: [PATCH 50/79] docs(middleware/session): Update session middleware idle timeout and configuration --- middleware/session/config.go | 108 +++++++++++++++++++------------ middleware/session/middleware.go | 2 + 2 files changed, 70 insertions(+), 40 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index b6dae196f8..3bfe9313f1 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -1,81 +1,100 @@ package session import ( + "fmt" "strings" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/session" "github.com/gofiber/utils/v2" ) -// Config defines the config for middleware. +// Config defines the configuration for the session middleware. type Config struct { - // Storage interface to store the session data - // Optional. Default value memory.New() + // Storage interface for storing session data. + // + // Optional. Default: memory.New() Storage fiber.Storage - // Next defines a function to skip this middleware when returned true. - // + // Next defines a function to skip this middleware when it returns true. // Optional. Default: nil - Next func(c fiber.Ctx) bool + Next func(c *fiber.Ctx) bool - // Store defines the session store + // Store defines the session store. // // Required. Store *Store - // ErrorHandler defines a function which is executed for errors + // ErrorHandler defines a function to handle errors. // // Optional. Default: nil ErrorHandler func(*fiber.Ctx, error) // KeyGenerator generates the session key. - // Optional. Default value utils.UUIDv4 + // + // Optional. Default: utils.UUIDv4 KeyGenerator func() string - // KeyLookup is a string in the form of ":" that is used - // to extract session id from the request. - // Possible values: "header:", "query:" or "cookie:" - // Optional. Default value "cookie:session_id". + // KeyLookup is a string in the format ":" used to extract the session ID from the request. + // + // Possible values: "header:", "query:", "cookie:" + // + // Optional. Default: "cookie:session_id" KeyLookup string - // Domain of the cookie. - // Optional. Default value "". + // CookieDomain defines the domain of the session cookie. + // + // Optional. Default: "" CookieDomain string - // Path of the cookie. - // Optional. Default value "". + // CookiePath defines the path of the session cookie. + // + // Optional. Default: "" CookiePath string - // Value of SameSite cookie. - // Optional. Default value "Lax". + // CookieSameSite specifies the SameSite attribute of the cookie. + // + // Optional. Default: "Lax" CookieSameSite string - // Source defines where to obtain the session id + // Source defines where to obtain the session ID. source Source - // The session name + // sessionName is the name of the session. sessionName string - // Allowed session idle duration - // Optional. Default value 30 * time.Minute. + // IdleTimeout defines the maximum duration of inactivity before the session expires. + // + // If set to a negative value, the session will never expire. + // Use this cautiously as it may lead to security issues. + // + // Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically. + // + // Optional. Default: 30 * time.Minute IdleTimeout time.Duration - // Indicates if cookie is secure. - // Optional. Default value false. + // CookieSecure specifies if the session cookie should be secure. + // + // Optional. Default: false CookieSecure bool - // Indicates if cookie is HTTP only. - // Optional. Default value false. + // CookieHTTPOnly specifies if the session cookie should be HTTP-only. + // + // Optional. Default: false CookieHTTPOnly bool - // Decides whether cookie should last for only the browser session. - // Ignores IdleTimeout if set to true - // Optional. Default value false. + // CookieSessionOnly determines if the cookie should expire when the browser session ends. + // + // If true, the cookie will be deleted when the browser is closed. + // Note: This will not delete the session data from the store. + // + // Optional. Default: false CookieSessionOnly bool } +// Source represents the type of session ID source. type Source string const ( @@ -84,12 +103,12 @@ const ( SourceURLQuery Source = "query" ) -// ConfigDefault is the default config +// ConfigDefault provides the default configuration. var ConfigDefault = Config{ IdleTimeout: 30 * time.Minute, KeyLookup: "cookie:session_id", KeyGenerator: utils.UUIDv4, - source: "cookie", + source: SourceCookie, sessionName: "session_id", } @@ -105,8 +124,8 @@ var ConfigDefault = Config{ func DefaultErrorHandler(c *fiber.Ctx, err error) { log.Errorf("session: %v", err) if c != nil { - if err := (*c).SendStatus(fiber.StatusInternalServerError); err != nil { - log.Errorf("session: %v", err) + if sendErr := (*c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { + log.Errorf("session: %v", sendErr) } } } @@ -114,7 +133,7 @@ func DefaultErrorHandler(c *fiber.Ctx, err error) { // configDefault sets default values for the Config struct. // // Parameters: -// - config: Variadic parameter to override default config. +// - config: Variadic parameter to override the default config. // // Returns: // - Config: The configuration with default values set. @@ -124,15 +143,15 @@ func DefaultErrorHandler(c *fiber.Ctx, err error) { // cfg := configDefault() // cfg := configDefault(customConfig) func configDefault(config ...Config) Config { - // Return default config if nothing provided + // Return default config if none provided. if len(config) < 1 { return ConfigDefault } - // Override default config + // Override default config with provided config. cfg := config[0] - // Set default values + // Set default values where necessary. if int(cfg.IdleTimeout.Seconds()) <= 0 { cfg.IdleTimeout = ConfigDefault.IdleTimeout } @@ -143,10 +162,11 @@ func configDefault(config ...Config) Config { cfg.KeyGenerator = ConfigDefault.KeyGenerator } + // Parse KeyLookup into source and session name. selectors := strings.Split(cfg.KeyLookup, ":") const numSelectors = 2 if len(selectors) != numSelectors { - panic("[session] KeyLookup must in the form of :") + panic("[session] KeyLookup must be in the format ':'") } switch Source(selectors[0]) { case SourceCookie: @@ -156,9 +176,17 @@ func configDefault(config ...Config) Config { case SourceURLQuery: cfg.source = SourceURLQuery default: - panic("[session] source is not supported") + panic("[session] unsupported source in KeyLookup") } cfg.sessionName = selectors[1] return cfg } + +// Example for Config struct. +func ExampleConfig() { + cfg := session.Config{ + IdleTimeout: 10 * time.Minute, + } + fmt.Println(cfg.IdleTimeout) +} diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 51ca5d32c9..5b5836cd59 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -1,3 +1,5 @@ +// Package session provides session management middleware for Fiber. +// This middleware allows you to manage user sessions, including storing session data in the store. package session import ( From 13a1eb43abac1b5fdc5b7c43a105017c8146a314 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 15:18:07 -0300 Subject: [PATCH 51/79] test(middleware/session): Fix tests for updated panics --- middleware/session/config.go | 12 +----------- middleware/session/config_test.go | 4 ++-- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index 3bfe9313f1..6f492cf65d 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -1,13 +1,11 @@ package session import ( - "fmt" "strings" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" - "github.com/gofiber/session" "github.com/gofiber/utils/v2" ) @@ -20,7 +18,7 @@ type Config struct { // Next defines a function to skip this middleware when it returns true. // Optional. Default: nil - Next func(c *fiber.Ctx) bool + Next func(c fiber.Ctx) bool // Store defines the session store. // @@ -182,11 +180,3 @@ func configDefault(config ...Config) Config { return cfg } - -// Example for Config struct. -func ExampleConfig() { - cfg := session.Config{ - IdleTimeout: 10 * time.Minute, - } - fmt.Println(cfg.IdleTimeout) -} diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go index 45456f6ea0..80d04f9750 100644 --- a/middleware/session/config_test.go +++ b/middleware/session/config_test.go @@ -47,13 +47,13 @@ func TestDefaultErrorHandler(t *testing.T) { } func TestInvalidKeyLookupFormat(t *testing.T) { - require.PanicsWithValue(t, "[session] KeyLookup must in the form of :", func() { + require.PanicsWithValue(t, "[session] KeyLookup must be in the format ':'", func() { configDefault(Config{KeyLookup: "invalid_format"}) }) } func TestUnsupportedSource(t *testing.T) { - require.PanicsWithValue(t, "[session] source is not supported", func() { + require.PanicsWithValue(t, "[session] unsupported source in KeyLookup", func() { configDefault(Config{KeyLookup: "unsupported:session_id"}) }) } From 9d3b0322405d7ce69285652ead6c7ee1df1d0088 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 16:40:40 -0300 Subject: [PATCH 52/79] refactor(middleware/session): Update session middleware initialization and saving --- middleware/session/middleware.go | 111 +++++++++++++++---------------- middleware/session/session.go | 32 ++++----- 2 files changed, 68 insertions(+), 75 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 5b5836cd59..b5857a3bdf 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -1,5 +1,5 @@ // Package session provides session management middleware for Fiber. -// This middleware allows you to manage user sessions, including storing session data in the store. +// This middleware handles user sessions, including storing session data in the store. package session import ( @@ -10,7 +10,7 @@ import ( "github.com/gofiber/fiber/v3/log" ) -// Middleware defines the session middleware configuration +// Middleware holds session data and configuration. type Middleware struct { Session *Session ctx *fiber.Ctx @@ -19,13 +19,14 @@ type Middleware struct { destroyed bool } -// key for looking up session middleware in request context +// Context key for session middleware lookup. const key = 0 var ( - // ErrTypeAssertionFailed is returned when the type assertion failed + // ErrTypeAssertionFailed occurs when a type assertion fails. ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") + // Pool for reusing middleware instances. middlewarePool = &sync.Pool{ New: func() any { return &Middleware{} @@ -33,7 +34,7 @@ var ( } ) -// New creates a new session middleware with the given configuration. +// New initializes session middleware with optional configuration. // // Parameters: // - config: Variadic parameter to override default config. @@ -44,18 +45,20 @@ var ( // Usage: // // app.Use(session.New()) +// +// Usage: +// +// app.Use(session.New()) func New(config ...Config) fiber.Handler { - var handler fiber.Handler if len(config) > 0 { - handler, _ = NewWithStore(config[0]) - } else { - handler, _ = NewWithStore() + handler, _ := NewWithStore(config[0]) + return handler } - + handler, _ := NewWithStore() return handler } -// NewWithStore returns a new session middleware with the given store. +// NewWithStore creates session middleware with an optional custom store. // // Parameters: // - config: Variadic parameter to override default config. @@ -75,29 +78,14 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { } handler := func(c fiber.Ctx) error { - // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } - // Get the session - session, err := cfg.Store.getSession(c) - if err != nil { - return err - } - - // get a middleware from the pool + // Acquire session middleware m := acquireMiddleware() - m.mu.Lock() - m.config = cfg - m.Session = session - m.ctx = &c - - // Store the middleware in the context - c.Locals(key, m) - m.mu.Unlock() + m.initialize(c, cfg) - // Continue stack stackErr := c.Next() m.mu.RLock() @@ -105,50 +93,56 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { m.mu.RUnlock() if !destroyed { - // Save the session - // This is done after the response is sent to the client - // It allows us to modify the session data during the request - // without having to worry about calling Save() on the session. - // - // It will also extend the session idle timeout automatically. - if err := session.saveSession(); err != nil { - if cfg.ErrorHandler != nil { - cfg.ErrorHandler(&c, err) - } else { - DefaultErrorHandler(&c, err) - } - } - - // Release the session back to the pool - releaseSession(session) + m.saveSession() } - // release the middleware back to the pool releaseMiddleware(m) - return stackErr } return handler, cfg.Store } -// acquireMiddleware returns a new Middleware from the pool. -// -// Returns: -// - *Middleware: The middleware object. -// -// Usage: -// -// m := acquireMiddleware() +// initialize sets up middleware for the request. +func (m *Middleware) initialize(c fiber.Ctx, cfg Config) { + m.mu.Lock() + defer m.mu.Unlock() + + session, err := cfg.Store.getSession(c) + if err != nil { + panic(err) // handle or log this error appropriately in production + } + + m.config = cfg + m.Session = session + m.ctx = &c + + c.Locals(key, m) +} + +// saveSession handles session saving and error management after the response. +func (m *Middleware) saveSession() { + if err := m.Session.saveSession(); err != nil { + if m.config.ErrorHandler != nil { + m.config.ErrorHandler(m.ctx, err) + } else { + DefaultErrorHandler(m.ctx, err) + } + } + + releaseSession(m.Session) +} + +// acquireMiddleware retrieves a middleware instance from the pool. func acquireMiddleware() *Middleware { - middleware, ok := middlewarePool.Get().(*Middleware) + m, ok := middlewarePool.Get().(*Middleware) if !ok { panic(ErrTypeAssertionFailed.Error()) } - return middleware + return m } -// releaseMiddleware returns a Middleware to the pool. +// releaseMiddleware resets and returns middleware to the pool. // // Parameters: // - m: The middleware object to release. @@ -289,8 +283,7 @@ func (m *Middleware) Reset() error { m.mu.Lock() defer m.mu.Unlock() - err := m.Session.Reset() - return err + return m.Session.Reset() } // Store returns the session store. diff --git a/middleware/session/session.go b/middleware/session/session.go index 7cb7a59d9f..ce29a01b7f 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -17,7 +17,7 @@ type Session struct { ctx fiber.Ctx // fiber context config *Store // store configuration data *data // key value data - byteBuffer *bytes.Buffer // byte buffer for the en- and decode + byteBuffer *bytes.Buffer // byte buffer for encoding/decoding id string // session id idleTimeout time.Duration // idleTimeout of this session mu sync.RWMutex // Mutex to protect non-data fields @@ -26,7 +26,9 @@ type Session struct { var sessionPool = sync.Pool{ New: func() any { - return new(Session) + return &Session{ + byteBuffer: new(bytes.Buffer), + } }, } @@ -43,9 +45,6 @@ func acquireSession() *Session { if s.data == nil { s.data = acquireData() } - if s.byteBuffer == nil { - s.byteBuffer = new(bytes.Buffer) - } s.fresh = true return s } @@ -90,7 +89,7 @@ func releaseSession(s *Session) { sessionPool.Put(s) } -// Fresh returns true if the current session is new. +// Fresh returns whether the session is new // // Returns: // - bool: True if the session is fresh, otherwise false. @@ -104,7 +103,7 @@ func (s *Session) Fresh() bool { return s.fresh } -// ID returns the session id. +// ID returns the session ID // // Returns: // - string: The session ID. @@ -263,12 +262,10 @@ func (s *Session) refresh() { s.fresh = true } -// Save updates the storage and client cookie. -// -// sess.Save() will save the session data to the storage and update the -// client cookie. +// Save saves the session data and updates the cookie // -// Checks if the session is being used in the handler, if so, it will not save the session. +// Note: If the session is being used in the handler, calling Save will have +// no effect and the session will automatically be saved when the handler returns. // // Returns: // - error: An error if the save operation fails. @@ -288,6 +285,7 @@ func (s *Session) Save() error { return s.saveSession() } +// saveSession encodes session data to saves it to storage. func (s *Session) saveSession() error { if s.data == nil { return nil @@ -296,7 +294,7 @@ func (s *Session) saveSession() error { s.mu.Lock() defer s.mu.Unlock() - // Check is the session has an idle timeout + // Set idleTimeout if not already set if s.idleTimeout <= 0 { s.idleTimeout = s.config.IdleTimeout } @@ -304,9 +302,11 @@ func (s *Session) saveSession() error { // Update client cookie s.setSession() - // Convert data to bytes + // Encode session data encCache := gob.NewEncoder(s.byteBuffer) + s.data.RLock() err := encCache.Encode(&s.data.Data) + s.data.RUnlock() if err != nil { return fmt.Errorf("failed to encode data: %w", err) } @@ -334,7 +334,7 @@ func (s *Session) Keys() []string { return s.data.Keys() } -// SetIdleTimeout sets a specific idle timeout for the session. +// SetIdleTimeout used when saving the session on the next call to `Save()`. // // Parameters: // - idleTimeout: The duration for the idle timeout. @@ -411,7 +411,7 @@ func (s *Session) delSession() { } } -// decodeSessionData decodes the session data from raw bytes. +// decodeSessionData decodes session data from raw bytes // // Parameters: // - rawData: The raw byte data to decode. From 9762767a4bd14b2294e16e86e41c93ef8d3901d0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 20 Sep 2024 16:41:52 -0300 Subject: [PATCH 53/79] refactor(middleware/session): Remove unnecessary comment about negative IdleTimeout value --- middleware/session/config.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index 6f492cf65d..0231d3a8c5 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -65,9 +65,6 @@ type Config struct { // IdleTimeout defines the maximum duration of inactivity before the session expires. // - // If set to a negative value, the session will never expire. - // Use this cautiously as it may lead to security issues. - // // Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically. // // Optional. Default: 30 * time.Minute From e59905f22cdbe42ea60c17b5bf69e7e5ec3d7318 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 25 Sep 2024 13:56:08 -0300 Subject: [PATCH 54/79] refactor(middleware/session): Update session middleware make NewStore public --- docs/middleware/session.md | 34 +++++++++++++++++---- docs/whats_new.md | 2 +- middleware/csrf/csrf_test.go | 6 ++-- middleware/session/config.go | 2 +- middleware/session/middleware.go | 25 +++++++++------- middleware/session/session.go | 2 +- middleware/session/session_test.go | 48 +++++++++++++++--------------- middleware/session/store.go | 14 +++++++-- middleware/session/store_test.go | 16 +++++----- 9 files changed, 94 insertions(+), 55 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 96cdb47169..60df9f1252 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -35,7 +35,7 @@ As of v3, we recommend using the middleware handler for session management. Howe ### v2 to v3 -- **Function Signature Change**: In v3, the `New` function now returns a middleware handler instead of a `*Store`. To access the store, use the `Store` method on `*Middleware` (obtained from `session.FromContext(c)` in a handler) or use `NewWithStore`. +- **Function Signature Change**: In v3, the `New` function now returns a middleware handler instead of a `*Store`. To access the store, use the `Store` method on `*Middleware` (obtained from `session.FromContext(c)` in a handler) or use `NewStore` or `NewWithStore`. - **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly. @@ -47,7 +47,7 @@ For more details about Fiber v3, see [What’s New](https://github.com/gofiber/f To convert a v2 example to use the v3 legacy approach, follow these steps: -1. **Initialize with Store**: Use `session.NewWithStore()` to obtain both the middleware handler and store. +1. **Initialize with Store**: Use `session.NewStore()` to obtain a store. 2. **Retrieve Session**: Access the session store using the `store.Get(c)` method. 3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle. @@ -87,7 +87,7 @@ app.Get("/", func(c *fiber.Ctx) error { **v3 Legacy Approach:** ```go -_, store := session.NewWithStore() +store := session.NewStore() app.Get("/", func(c *fiber.Ctx) error { sess, err := store.Get(c) @@ -163,7 +163,7 @@ type Session struct {} ### Store -Handles session data management and is created using `NewWithStore` or by accessing the `Store` method of a middleware instance. +Handles session data management and is created using `NewStore`, `NewWithStore` or by accessing the `Store` method of a middleware instance. ```go type Store struct { @@ -316,7 +316,7 @@ import ( func main() { app := fiber.New() - _, sessionStore := session.NewWithStore() + sessionStore := session.NewStore() app.Use(csrf.New(csrf.Config{ Store: sessionStore, @@ -327,6 +327,7 @@ func main() { if err != nil { return c.SendStatus(fiber.StatusInternalServerError) } + defer sess.Release() name, ok := sess.Get("name").(string) if !ok { @@ -336,6 +337,29 @@ func main() { return c.SendString("Welcome " + name) }) + app.Post("/login", func(c *fiber.Ctx) error { + sess, err := sessionStore.Get(c) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + defer sess.Release() + + if !sess.Fresh() { + if err := sess.Regenerate(); err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + } + + sess.Set("name", "John Doe") + + err = sess.Save() + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + + return c.SendString("Logged in!") + }) + app.Listen(":3000") } ``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 4cb6b20b18..400d8a4fe9 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -318,7 +318,7 @@ The Session middleware has undergone key changes in v3 to improve functionality #### Key Updates -- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewWithStore` for custom store integration. +- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration. - **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 133752abba..0e486d4830 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -70,7 +70,7 @@ func Test_CSRF_WithSession(t *testing.T) { t.Parallel() // session store - _, store := session.NewWithStore(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -260,7 +260,7 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { t.Parallel() // session store - _, store := session.NewWithStore(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) @@ -1131,7 +1131,7 @@ func Test_CSRF_DeleteToken_WithSession(t *testing.T) { t.Parallel() // session store - _, store := session.NewWithStore(session.Config{ + store := session.NewStore(session.Config{ KeyLookup: "cookie:_session", }) diff --git a/middleware/session/config.go b/middleware/session/config.go index 0231d3a8c5..a4a509ffcc 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -147,7 +147,7 @@ func configDefault(config ...Config) Config { cfg := config[0] // Set default values where necessary. - if int(cfg.IdleTimeout.Seconds()) <= 0 { + if cfg.IdleTimeout <= 0 { cfg.IdleTimeout = ConfigDefault.IdleTimeout } if cfg.KeyLookup == "" { diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index b5857a3bdf..fbc08d876c 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -20,7 +20,12 @@ type Middleware struct { } // Context key for session middleware lookup. -const key = 0 +type middlewareKey int + +const ( + // middlewareContextKey is the key used to store the *Middleware in the context locals. + middlewareContextKey middlewareKey = iota +) var ( // ErrTypeAssertionFailed occurs when a type assertion fails. @@ -74,7 +79,7 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { cfg := configDefault(config...) if cfg.Store == nil { - cfg.Store = newStore(cfg) + cfg.Store = NewStore(cfg) } handler := func(c fiber.Ctx) error { @@ -117,7 +122,7 @@ func (m *Middleware) initialize(c fiber.Ctx, cfg Config) { m.Session = session m.ctx = &c - c.Locals(key, m) + c.Locals(middlewareContextKey, m) } // saveSession handles session saving and error management after the response. @@ -172,7 +177,7 @@ func releaseMiddleware(m *Middleware) { // // m := session.FromContext(c) func FromContext(c fiber.Ctx) *Middleware { - m, ok := c.Locals(key).(*Middleware) + m, ok := c.Locals(middlewareContextKey).(*Middleware) if !ok { // TODO: since this may be called we may not want to log this except in debug mode? log.Warn("session: Session middleware not registered. See https://docs.gofiber.io/middleware/session") @@ -190,11 +195,11 @@ func FromContext(c fiber.Ctx) *Middleware { // Usage: // // m.Set("key", "value") -func (m *Middleware) Set(key string, value any) { +func (m *Middleware) Set(middlewareContextKey string, value any) { m.mu.Lock() defer m.mu.Unlock() - m.Session.Set(key, value) + m.Session.Set(middlewareContextKey, value) } // Get retrieves a value from the session by key. @@ -208,11 +213,11 @@ func (m *Middleware) Set(key string, value any) { // Usage: // // value := m.Get("key") -func (m *Middleware) Get(key string) any { +func (m *Middleware) Get(middlewareContextKey string) any { m.mu.RLock() defer m.mu.RUnlock() - return m.Session.Get(key) + return m.Session.Get(middlewareContextKey) } // Delete removes a key-value pair from the session. @@ -223,11 +228,11 @@ func (m *Middleware) Get(key string) any { // Usage: // // m.Delete("key") -func (m *Middleware) Delete(key string) { +func (m *Middleware) Delete(middlewareContextKey string) { m.mu.Lock() defer m.mu.Unlock() - m.Session.Delete(key) + m.Session.Delete(middlewareContextKey) } // Destroy destroys the session. diff --git a/middleware/session/session.go b/middleware/session/session.go index ce29a01b7f..bcacd351f6 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -275,7 +275,7 @@ func (s *Session) refresh() { // err := s.Save() func (s *Session) Save() error { // If the session is being used in the handler, it should not be saved - if m, ok := s.ctx.Locals(key).(*Middleware); ok { + if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok { if m.Session == s { // Session is in use, so we do nothing and return return nil diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 5a628d4723..1787a610fe 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -17,7 +17,7 @@ func Test_Session(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() @@ -118,7 +118,7 @@ func Test_Session_Types(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() @@ -293,7 +293,7 @@ func Test_Session_Types(t *testing.T) { func Test_Session_Store_Reset(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -334,7 +334,7 @@ func Test_Session_Save(t *testing.T) { t.Run("save to cookie", func(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -355,7 +355,7 @@ func Test_Session_Save(t *testing.T) { t.Run("save to header", func(t *testing.T) { t.Parallel() // session store - store := newStore(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -387,7 +387,7 @@ func Test_Session_Save_Expiration(t *testing.T) { const sessionDuration = 5 * time.Second // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -446,7 +446,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from cookie", func(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -467,7 +467,7 @@ func Test_Session_Destroy(t *testing.T) { t.Run("destroy from header", func(t *testing.T) { t.Parallel() // session store - store := newStore(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber instance @@ -507,11 +507,11 @@ func Test_Session_Destroy(t *testing.T) { func Test_Session_Custom_Config(t *testing.T) { t.Parallel() - store := newStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) + store := NewStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) require.Equal(t, time.Hour, store.IdleTimeout) require.Equal(t, "very random", store.KeyGenerator()) - store = newStore(Config{IdleTimeout: 0}) + store = NewStore(Config{IdleTimeout: 0}) require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout) } @@ -519,7 +519,7 @@ func Test_Session_Custom_Config(t *testing.T) { func Test_Session_Cookie(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber instance app := fiber.New() // fiber context @@ -543,7 +543,7 @@ func Test_Session_Cookie(t *testing.T) { // Regression: https://github.com/gofiber/fiber/pull/1191 func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { t.Parallel() - store := newStore() + store := NewStore() app := fiber.New() // fiber context @@ -575,7 +575,7 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { // Regression: https://github.com/gofiber/fiber/issues/1365 func Test_Session_Deletes_Single_Key(t *testing.T) { t.Parallel() - store := newStore() + store := NewStore() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -617,7 +617,7 @@ func Test_Session_Reset(t *testing.T) { app := fiber.New() // session store - store := newStore() + store := NewStore() t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { t.Parallel() @@ -691,7 +691,7 @@ func Test_Session_Regenerate(t *testing.T) { t.Run("set fresh to be true when regenerating a session", func(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // a random session uuid originalSessionUUIDString := "" // fiber context @@ -740,7 +740,7 @@ func Test_Session_Regenerate(t *testing.T) { // go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4 func Benchmark_Session(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), newStore() + app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -758,7 +758,7 @@ func Benchmark_Session(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := newStore(Config{ + store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -780,7 +780,7 @@ func Benchmark_Session(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4 func Benchmark_Session_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), newStore() + app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -801,7 +801,7 @@ func Benchmark_Session_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := newStore(Config{ + store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() @@ -826,7 +826,7 @@ func Benchmark_Session_Parallel(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4 func Benchmark_Session_Asserted(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), newStore() + app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie(store.sessionName, "12356789") @@ -845,7 +845,7 @@ func Benchmark_Session_Asserted(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := newStore(Config{ + store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -868,7 +868,7 @@ func Benchmark_Session_Asserted(b *testing.B) { // go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4 func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { - app, store := fiber.New(), newStore() + app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -888,7 +888,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("storage", func(b *testing.B) { app := fiber.New() - store := newStore(Config{ + store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() @@ -913,7 +913,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { func Test_Session_Concurrency(t *testing.T) { t.Parallel() app := fiber.New() - store := newStore() + store := NewStore() var wg sync.WaitGroup errChan := make(chan error, 10) // Buffered channel to collect errors diff --git a/middleware/session/store.go b/middleware/session/store.go index 9fe2ff868e..9d525328a3 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -30,7 +30,17 @@ type Store struct { } // New creates a new session store with the provided configuration. -func newStore(config ...Config) *Store { +// +// Parameters: +// - config: Variadic parameter to override default config. +// +// Returns: +// - *Store: The session store. +// +// Usage: +// +// store := session.New() +func NewStore(config ...Config) *Store { // Set default config cfg := configDefault(config...) @@ -76,7 +86,7 @@ func (*Store) RegisterType(i any) { func (s *Store) Get(c fiber.Ctx) (*Session, error) { // If session is already loaded in the context, // it should not be loaded again - _, ok := c.Locals(key).(*Middleware) + _, ok := c.Locals(middlewareContextKey).(*Middleware) if ok { return nil, ErrSessionAlreadyLoadedByMiddleware } diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 06a744d138..f57cc38fc2 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -20,7 +20,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from cookie", func(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -33,7 +33,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from header", func(t *testing.T) { t.Parallel() // session store - store := newStore(Config{ + store := NewStore(Config{ KeyLookup: "header:session_id", }) // fiber context @@ -48,7 +48,7 @@ func Test_Store_getSessionID(t *testing.T) { t.Run("from url query", func(t *testing.T) { t.Parallel() // session store - store := newStore(Config{ + store := NewStore(Config{ KeyLookup: "query:session_id", }) // fiber context @@ -73,7 +73,7 @@ func Test_Store_Get(t *testing.T) { t.Run("session should be re-generated if it is invalid", func(t *testing.T) { t.Parallel() // session store - store := newStore() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -93,7 +93,7 @@ func Test_Store_DeleteSession(t *testing.T) { // fiber instance app := fiber.New() // session store - store := newStore() + store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -126,7 +126,7 @@ func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { // Mock middleware and set it in the context middleware := &Middleware{} - ctx.Locals(key, middleware) + ctx.Locals(middlewareContextKey, middleware) // Create a new store store := &Store{} @@ -141,7 +141,7 @@ func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { func TestStore_Delete(t *testing.T) { // Create a new store - store := newStore() + store := NewStore() t.Run("delete with empty session ID", func(t *testing.T) { err := store.Delete("") @@ -158,7 +158,7 @@ func TestStore_Delete(t *testing.T) { func Test_Store_GetSessionByID(t *testing.T) { t.Parallel() // Create a new store - store := newStore() + store := NewStore() t.Run("empty session ID", func(t *testing.T) { t.Parallel() From 8716c95a746bb8996369510880dbe3788ba2888d Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 25 Sep 2024 14:25:45 -0300 Subject: [PATCH 55/79] refactor(middleware/session): Update session middleware Set, Get, and Delete methods Refactor the Set, Get, and Delete methods in the session middleware to use more descriptive parameter names. Instead of using "middlewareContextKey", the methods now use "key" to represent the key of the session value. This improves the readability and clarity of the code. --- middleware/session/middleware.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index fbc08d876c..5ff9f5fba1 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -195,11 +195,11 @@ func FromContext(c fiber.Ctx) *Middleware { // Usage: // // m.Set("key", "value") -func (m *Middleware) Set(middlewareContextKey string, value any) { +func (m *Middleware) Set(key string, value any) { m.mu.Lock() defer m.mu.Unlock() - m.Session.Set(middlewareContextKey, value) + m.Session.Set(key, value) } // Get retrieves a value from the session by key. @@ -213,11 +213,11 @@ func (m *Middleware) Set(middlewareContextKey string, value any) { // Usage: // // value := m.Get("key") -func (m *Middleware) Get(middlewareContextKey string) any { +func (m *Middleware) Get(key string) any { m.mu.RLock() defer m.mu.RUnlock() - return m.Session.Get(middlewareContextKey) + return m.Session.Get(key) } // Delete removes a key-value pair from the session. @@ -228,11 +228,11 @@ func (m *Middleware) Get(middlewareContextKey string) any { // Usage: // // m.Delete("key") -func (m *Middleware) Delete(middlewareContextKey string) { +func (m *Middleware) Delete(key string) { m.mu.Lock() defer m.mu.Unlock() - m.Session.Delete(middlewareContextKey) + m.Session.Delete(key) } // Destroy destroys the session. From 951691da2ecee382c411bd45261dcb8413899ed6 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 14:19:15 -0300 Subject: [PATCH 56/79] feat(middleware/session): AbsoluteTimeout and key any --- middleware/session/config.go | 7 ++ middleware/session/data.go | 18 +-- middleware/session/data_msgp.go | 39 +++--- middleware/session/middleware.go | 6 +- middleware/session/session.go | 60 +++++++++- middleware/session/session_test.go | 183 ++++++++++++++++++++++++++++- middleware/session/store.go | 26 +++- 7 files changed, 301 insertions(+), 38 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index a4a509ffcc..992173b4db 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -70,6 +70,13 @@ type Config struct { // Optional. Default: 30 * time.Minute IdleTimeout time.Duration + // AbsoluteTimeout defines the maximum duration of the session before it expires. + // + // If set to 0, the session will not have an absolute timeout, and will expire after the idle timeout. + // + // Optional. Default: 0 + AbsoluteTimeout time.Duration + // CookieSecure specifies if the session cookie should be secure. // // Optional. Default: false diff --git a/middleware/session/data.go b/middleware/session/data.go index 93f7c06f57..2c51ac39bf 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -8,14 +8,14 @@ import ( // //go:generate msgp -o=data_msgp.go -tests=true -unexported type data struct { - Data map[string]any + Data map[any]any sync.RWMutex `msg:"-"` } var dataPool = sync.Pool{ New: func() any { d := new(data) - d.Data = make(map[string]any) + d.Data = make(map[any]any) return d }, } @@ -45,7 +45,7 @@ func acquireData() *data { func (d *data) Reset() { d.Lock() defer d.Unlock() - d.Data = make(map[string]any) + d.Data = make(map[any]any) } // Get retrieves a value from the data map by key. @@ -59,7 +59,7 @@ func (d *data) Reset() { // Usage: // // value := d.Get("key") -func (d *data) Get(key string) any { +func (d *data) Get(key any) any { d.RLock() defer d.RUnlock() return d.Data[key] @@ -74,7 +74,7 @@ func (d *data) Get(key string) any { // Usage: // // d.Set("key", "value") -func (d *data) Set(key string, value any) { +func (d *data) Set(key any, value any) { d.Lock() defer d.Unlock() d.Data[key] = value @@ -88,7 +88,7 @@ func (d *data) Set(key string, value any) { // Usage: // // d.Delete("key") -func (d *data) Delete(key string) { +func (d *data) Delete(key any) { d.Lock() defer d.Unlock() delete(d.Data, key) @@ -97,15 +97,15 @@ func (d *data) Delete(key string) { // Keys retrieves all keys in the data map. // // Returns: -// - []string: A slice of all keys in the data map. +// - []any: A slice of all keys in the data map. // // Usage: // // keys := d.Keys() -func (d *data) Keys() []string { +func (d *data) Keys() []any { d.RLock() defer d.RUnlock() - keys := make([]string, 0, len(d.Data)) + keys := make([]any, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go index ce3af1bd17..934f16ba91 100644 --- a/middleware/session/data_msgp.go +++ b/middleware/session/data_msgp.go @@ -43,7 +43,7 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { return } if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) + z.Data = make(map[any]any, zb0002) } else if len(z.Data) > 0 { for key := range z.Data { delete(z.Data, key) @@ -51,9 +51,9 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { } for zb0002 > 0 { zb0002-- - var za0001 string - var za0002 interface{} - za0001, err = dc.ReadString() + var za0001 any + var za0002 any + za0001, err = dc.ReadIntf() if err != nil { err = msgp.WrapError(err, "Data") return @@ -101,14 +101,18 @@ func (z *data) EncodeMsg(en *msgp.Writer) (err error) { return } for za0001, za0002 := range z.Data { - err = en.WriteString(za0001) + keyStr, ok := za0001.(string) + if !ok { + return msgp.WrapError(err, "Data", za0001) + } + err = en.WriteString(keyStr) if err != nil { err = msgp.WrapError(err, "Data") return } err = en.WriteIntf(za0002) if err != nil { - err = msgp.WrapError(err, "Data", za0001) + err = msgp.WrapError(err, "Data", keyStr) return } } @@ -135,10 +139,14 @@ func (z *data) MarshalMsg(b []byte) (o []byte, err error) { o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) o = msgp.AppendMapHeader(o, uint32(len(z.Data))) for za0001, za0002 := range z.Data { - o = msgp.AppendString(o, za0001) + keyStr, ok := za0001.(string) + if !ok { + return nil, msgp.WrapError(err, "Data", za0001) + } + o = msgp.AppendString(o, keyStr) o, err = msgp.AppendIntf(o, za0002) if err != nil { - err = msgp.WrapError(err, "Data", za0001) + err = msgp.WrapError(err, "Data", keyStr) return } } @@ -183,17 +191,17 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { return } if z.Data == nil { - z.Data = make(map[string]interface{}, zb0002) + z.Data = make(map[any]any, zb0002) } else if len(z.Data) > 0 { for key := range z.Data { delete(z.Data, key) } } for zb0002 > 0 { - var za0001 string - var za0002 interface{} + var za0001 any + var za0002 any zb0002-- - za0001, bts, err = msgp.ReadStringBytes(bts) + za0001, bts, err = msgp.ReadIntfBytes(bts) if err != nil { err = msgp.WrapError(err, "Data") return @@ -230,8 +238,11 @@ func (z *data) Msgsize() (s int) { s = 1 + 5 + msgp.MapHeaderSize if z.Data != nil { for za0001, za0002 := range z.Data { - _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.GuessSize(za0002) + keyStr, ok := za0001.(string) + if !ok { + continue + } + s += msgp.StringPrefixSize + len(keyStr) + msgp.GuessSize(za0002) } } return diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 5ff9f5fba1..a4fb6128bf 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -195,7 +195,7 @@ func FromContext(c fiber.Ctx) *Middleware { // Usage: // // m.Set("key", "value") -func (m *Middleware) Set(key string, value any) { +func (m *Middleware) Set(key any, value any) { m.mu.Lock() defer m.mu.Unlock() @@ -213,7 +213,7 @@ func (m *Middleware) Set(key string, value any) { // Usage: // // value := m.Get("key") -func (m *Middleware) Get(key string) any { +func (m *Middleware) Get(key any) any { m.mu.RLock() defer m.mu.RUnlock() @@ -228,7 +228,7 @@ func (m *Middleware) Get(key string) any { // Usage: // // m.Delete("key") -func (m *Middleware) Delete(key string) { +func (m *Middleware) Delete(key any) { m.mu.Lock() defer m.mu.Unlock() diff --git a/middleware/session/session.go b/middleware/session/session.go index bcacd351f6..286602c233 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -24,6 +24,13 @@ type Session struct { fresh bool // if new session } +type expirationKeyType int + +const ( + // sessionIDContextKey is the key used to store the session ID in the context locals. + expirationKey expirationKeyType = iota +) + var sessionPool = sync.Pool{ New: func() any { return &Session{ @@ -128,7 +135,7 @@ func (s *Session) ID() string { // Usage: // // value := s.Get("key") -func (s *Session) Get(key string) any { +func (s *Session) Get(key any) any { if s.data == nil { return nil } @@ -144,7 +151,7 @@ func (s *Session) Get(key string) any { // Usage: // // s.Set("key", "value") -func (s *Session) Set(key string, val any) { +func (s *Session) Set(key any, val any) { if s.data == nil { return } @@ -159,7 +166,7 @@ func (s *Session) Set(key string, val any) { // Usage: // // s.Delete("key") -func (s *Session) Delete(key string) { +func (s *Session) Delete(key any) { if s.data == nil { return } @@ -327,9 +334,9 @@ func (s *Session) saveSession() error { // Usage: // // keys := s.Keys() -func (s *Session) Keys() []string { +func (s *Session) Keys() []any { if s.data == nil { - return []string{} + return []any{} } return s.data.Keys() } @@ -430,3 +437,46 @@ func (s *Session) decodeSessionData(rawData []byte) error { } return nil } + +// expiration returns the session expiration time or a zero time if not set. +// +// Returns: +// - time.Time: The session expiration time, or a zero time if not set. +// +// Usage: +// +// expiration := s.expiration() +func (s *Session) expiration() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + expiration, ok := s.Get(expirationKey).(time.Time) + if ok { + return expiration + } + return time.Time{} +} + +// isExpired returns true if the session is expired. +// +// If the session expiration time is zero, the session is considered to never expire. +// +// Returns: +// - bool: True if the session is expired, otherwise false. +func (s *Session) isExpired() bool { + expiration := s.expiration() + return !expiration.IsZero() && time.Now().After(expiration) +} + +// setExpiration sets the session expiration time. +// +// Parameters: +// - expiration: The session expiration time. +// +// Usage: +// +// s.setExpiration(time.Now().Add(time.Hour)) +func (s *Session) setExpiration(expiration time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + s.Set(expirationKey, expiration) +} diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 1787a610fe..b503cb9c99 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -46,7 +46,7 @@ func Test_Session(t *testing.T) { // get keys keys := sess.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // get value name := sess.Get("name") @@ -60,7 +60,7 @@ func Test_Session(t *testing.T) { require.Equal(t, "john", name) keys = sess.Keys() - require.Equal(t, []string{"name"}, keys) + require.Equal(t, []any{"name"}, keys) // delete key sess.Delete("name") @@ -71,7 +71,7 @@ func Test_Session(t *testing.T) { // get keys keys = sess.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // get id id := sess.ID() @@ -327,6 +327,181 @@ func Test_Session_Store_Reset(t *testing.T) { require.Nil(t, sess.Get("hello")) } +func Test_Session_KeyTypes(t *testing.T) { + t.Parallel() + + // session store + store := NewStore() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + // get session + sess, err := store.Get(ctx) + require.NoError(t, err) + require.True(t, sess.Fresh()) + + type Person struct { + Name string + } + + type unexportedKey int + + // register non-default types + store.RegisterType(Person{}) + store.RegisterType(unexportedKey(0)) + + type unregisteredKeyType int + type unregisteredValueType int + + // verify unregistered keys types are not allowed + var ( + unregisteredKey unregisteredKeyType = 0 + unregisteredValue unregisteredValueType = 1 + ) + sess.Set(unregisteredKey, "test") + err = sess.Save() + require.Error(t, err) + sess.Delete(unregisteredKey) + err = sess.Save() + require.NoError(t, err) + sess.Set("abc", unregisteredValue) + err = sess.Save() + require.Error(t, err) + sess.Delete("abc") + err = sess.Save() + require.NoError(t, err) + + sess.Reset() + + var ( + kbool = true + kstring = "str" + kint = 13 + kint8 int8 = 13 + kint16 int16 = 13 + kint32 int32 = 13 + kint64 int64 = 13 + kuint uint = 13 + kuint8 uint8 = 13 + kuint16 uint16 = 13 + kuint32 uint32 = 13 + kuint64 uint64 = 13 + kuintptr uintptr = 13 + kbyte byte = 'k' + krune = 'k' + kfloat32 float32 = 13 + kfloat64 float64 = 13 + kcomplex64 complex64 = 13 + kcomplex128 complex128 = 13 + kuser = Person{Name: "John"} + kunexportedKey = unexportedKey(13) + ) + + var ( + vbool = true + vstring = "str" + vint = 13 + vint8 int8 = 13 + vint16 int16 = 13 + vint32 int32 = 13 + vint64 int64 = 13 + vuint uint = 13 + vuint8 uint8 = 13 + vuint16 uint16 = 13 + vuint32 uint32 = 13 + vuint64 uint64 = 13 + vuintptr uintptr = 13 + vbyte byte = 'k' + vrune = 'k' + vfloat32 float32 = 13 + vfloat64 float64 = 13 + vcomplex64 complex64 = 13 + vcomplex128 complex128 = 13 + vuser = Person{Name: "John"} + vunexportedKey = unexportedKey(13) + ) + + keys := []any{ + kbool, + kstring, + kint, + kint8, + kint16, + kint32, + kint64, + kuint, + kuint8, + kuint16, + kuint32, + kuint64, + kuintptr, + kbyte, + krune, + kfloat32, + kfloat64, + kcomplex64, + kcomplex128, + kuser, + kunexportedKey, + } + + values := []any{ + vbool, + vstring, + vint, + vint8, + vint16, + vint32, + vint64, + vuint, + vuint8, + vuint16, + vuint32, + vuint64, + vuintptr, + vbyte, + vrune, + vfloat32, + vfloat64, + vcomplex64, + vcomplex128, + vuser, + vunexportedKey, + } + + // loop test all key value pairs + for i, key := range keys { + sess.Set(key, values[i]) + } + + id := sess.ID() + ctx.Request().Header.SetCookie(store.sessionName, id) + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie(store.sessionName, id) + + // get session + sess, err = store.Get(ctx) + require.NoError(t, err) + defer sess.Release() + require.False(t, sess.Fresh()) + + // loop test all key value pairs + for i, key := range keys { + // get value + result := sess.Get(key) + require.Equal(t, values[i], result) + } +} + // go test -run Test_Session_Save func Test_Session_Save(t *testing.T) { t.Parallel() @@ -661,7 +836,7 @@ func Test_Session_Reset(t *testing.T) { // Check that the session data has been reset keys := acquiredSession.Keys() - require.Equal(t, []string{}, keys) + require.Equal(t, []any{}, keys) // Set a new value for 'name' and check that it's updated acquiredSession.Set("name", "john") diff --git a/middleware/session/store.go b/middleware/session/store.go index 9d525328a3..f550e81b46 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -4,9 +4,11 @@ import ( "encoding/gob" "errors" "fmt" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -143,7 +145,6 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { sess := acquireSession() sess.mu.Lock() - defer sess.mu.Unlock() sess.ctx = c sess.config = s @@ -158,6 +159,15 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { return nil, fmt.Errorf("failed to decode session data: %w", err) } } + sess.mu.Unlock() + + if fresh && s.AbsoluteTimeout > 0 { + sess.setExpiration(time.Now().Add(s.AbsoluteTimeout)) + } else if sess.isExpired() { + if err := sess.Reset(); err != nil { + return nil, fmt.Errorf("failed to reset session: %w", err) + } + } return sess, nil } @@ -281,10 +291,20 @@ func (s *Store) GetSessionByID(id string) (*Session, error) { sess.config = s sess.data.Lock() - defer sess.data.Unlock() - if err := sess.decodeSessionData(rawData); err != nil { + decodeErr := sess.decodeSessionData(rawData) + sess.data.Unlock() + if decodeErr != nil { return nil, fmt.Errorf("failed to decode session data: %w", err) } + if s.AbsoluteTimeout > 0 { + if sess.isExpired() { + if err := sess.Destroy(); err != nil { + log.Errorf("failed to destroy expired session: %v", err) + } + return nil, ErrSessionIDNotFoundInStore + } + } + return sess, nil } From 3ac9b68f2a45df353c4d7292f9d331754cbdadd3 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 14:40:58 -0300 Subject: [PATCH 57/79] fix(middleware/session): locking issues and lint errors --- docs/middleware/session.md | 6 +- docs/whats_new.md | 2 + middleware/session/data.go | 2 +- middleware/session/data_msgp.go | 355 +++++++++++++++-------------- middleware/session/middleware.go | 2 +- middleware/session/session.go | 6 +- middleware/session/session_test.go | 6 +- middleware/session/store.go | 10 +- 8 files changed, 199 insertions(+), 190 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 60df9f1252..46e6a04775 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -39,7 +39,7 @@ As of v3, we recommend using the middleware handler for session management. Howe - **Session Lifecycle Management**: The `*Store.Save` method no longer releases the instance automatically. You must manually call `sess.Release()` after using the session to manage its lifecycle properly. -- **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with the `IdleTimeout` field, which explicitly defines the session's idle timeout period. Users who need to set a maximum session duration must now implement this logic themselves using data stored in the session. +- **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with `IdleTimeout` and `AbsoluteTimeout` fields, which explicitly defines the session's idle and absolute timeout periods. For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). @@ -52,7 +52,7 @@ To convert a v2 example to use the v3 legacy approach, follow these steps: 3. **Release Session**: Ensure that you call `sess.Release()` after you are done with the session to manage its lifecycle. :::note -When using the legacy approach, the IdleTimeout will only be updated when the session is saved. +When using the legacy approach, the IdleTimeout will be updated when the session is saved. ::: #### Example Conversion @@ -427,6 +427,7 @@ func main() { | **CookiePath** | `string` | The path scope of the session cookie. | `"/"` | | **CookieSameSite** | `string` | The SameSite attribute of the session cookie. | `"Lax"` | | **IdleTimeout** | `time.Duration` | Maximum duration of inactivity before session expires. | `30 * time.Minute` | +| **AbsoluteTimeout** | `time.Duration` | Maximum duration before session expires. | `0` (no expiration) | | **CookieSecure** | `bool` | Ensures session cookie is only sent over HTTPS. | `false` | | **CookieHTTPOnly** | `bool` | Ensures session cookie is not accessible to JavaScript (HTTP only). | `true` | | **CookieSessionOnly** | `bool` | Prevents session cookie from being saved after the session ends (cookie expires on close). | `false` | @@ -445,6 +446,7 @@ session.Config{ CookiePath: "", CookieSameSite: "Lax", IdleTimeout: 30 * time.Minute, + AbsoluteTimeout: 0, CookieSecure: false, CookieHTTPOnly: false, CookieSessionOnly: false, diff --git a/docs/whats_new.md b/docs/whats_new.md index 400d8a4fe9..577223b9a0 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -324,6 +324,8 @@ The Session middleware has undergone key changes in v3 to improve functionality - **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which strictly handles session inactivity. If you require a maximum session duration, you'll need to implement it within your own session data. +- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. + For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). ### Filesystem diff --git a/middleware/session/data.go b/middleware/session/data.go index 2c51ac39bf..052e43bc1b 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -74,7 +74,7 @@ func (d *data) Get(key any) any { // Usage: // // d.Set("key", "value") -func (d *data) Set(key any, value any) { +func (d *data) Set(key, value any) { d.Lock() defer d.Unlock() d.Data[key] = value diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go index 934f16ba91..483eea43ce 100644 --- a/middleware/session/data_msgp.go +++ b/middleware/session/data_msgp.go @@ -3,7 +3,7 @@ package session // Code generated by github.com/tinylib/msgp DO NOT EDIT. import ( - "github.com/tinylib/msgp/msgp" + "github.com/tinylib/msgp/msgp" ) // DecodeMsg implements msgp.Decodable @@ -17,63 +17,64 @@ import ( // - error: An error if the decoding fails. // // Usage: -// err := d.DecodeMsg(reader) +// +// err := d.DecodeMsg(reader) func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, err = dc.ReadMapKeyPtr() - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[any]any, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - zb0002-- - var za0001 any - var za0002 any - za0001, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } - default: - err = dc.Skip() - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - return + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Data": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + if z.Data == nil { + z.Data = make(map[any]any, zb0002) + } else if len(z.Data) > 0 { + for key := range z.Data { + delete(z.Data, key) + } + } + for zb0002 > 0 { + zb0002-- + var za0001 any + var za0002 any + za0001, err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + za0002, err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + z.Data[za0001] = za0002 + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return } // EncodeMsg implements msgp.Encodable @@ -87,36 +88,37 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { // - error: An error if the encoding fails. // // Usage: -// err := d.EncodeMsg(writer) +// +// err := d.EncodeMsg(writer) func (z *data) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 1 - // write "Data" - err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - if err != nil { - return - } - err = en.WriteMapHeader(uint32(len(z.Data))) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - return msgp.WrapError(err, "Data", za0001) - } - err = en.WriteString(keyStr) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - err = en.WriteIntf(za0002) - if err != nil { - err = msgp.WrapError(err, "Data", keyStr) - return - } - } - return + // map header, size 1 + // write "Data" + err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) + if err != nil { + return + } + err = en.WriteMapHeader(uint32(len(z.Data))) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + for za0001, za0002 := range z.Data { + keyStr, ok := za0001.(string) + if !ok { + return msgp.WrapError(err, "Data", za0001) + } + err = en.WriteString(keyStr) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + err = en.WriteIntf(za0002) + if err != nil { + err = msgp.WrapError(err, "Data", keyStr) + return + } + } + return } // MarshalMsg implements msgp.Marshaler @@ -131,26 +133,27 @@ func (z *data) EncodeMsg(en *msgp.Writer) (err error) { // - error: An error if the marshaling fails. // // Usage: -// b, err := d.MarshalMsg(nil) +// +// b, err := d.MarshalMsg(nil) func (z *data) MarshalMsg(b []byte) (o []byte, err error) { - o = msgp.Require(b, z.Msgsize()) - // map header, size 1 - // string "Data" - o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - o = msgp.AppendMapHeader(o, uint32(len(z.Data))) - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - return nil, msgp.WrapError(err, "Data", za0001) - } - o = msgp.AppendString(o, keyStr) - o, err = msgp.AppendIntf(o, za0002) - if err != nil { - err = msgp.WrapError(err, "Data", keyStr) - return - } - } - return + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "Data" + o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) + o = msgp.AppendMapHeader(o, uint32(len(z.Data))) + for za0001, za0002 := range z.Data { + keyStr, ok := za0001.(string) + if !ok { + return nil, msgp.WrapError(err, "Data", za0001) + } + o = msgp.AppendString(o, keyStr) + o, err = msgp.AppendIntf(o, za0002) + if err != nil { + err = msgp.WrapError(err, "Data", keyStr) + return + } + } + return } // UnmarshalMsg implements msgp.Unmarshaler @@ -165,64 +168,65 @@ func (z *data) MarshalMsg(b []byte) (o []byte, err error) { // - error: An error if the unmarshaling fails. // // Usage: -// b, err := d.UnmarshalMsg(bts) +// +// b, err := d.UnmarshalMsg(bts) func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0001 uint32 - zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - for zb0001 > 0 { - zb0001-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[any]any, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - var za0001 any - var za0002 any - zb0002-- - za0001, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } - default: - bts, err = msgp.Skip(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - o = bts - return + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "Data": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + if z.Data == nil { + z.Data = make(map[any]any, zb0002) + } else if len(z.Data) > 0 { + for key := range z.Data { + delete(z.Data, key) + } + } + for zb0002 > 0 { + var za0001 any + var za0002 any + zb0002-- + za0001, bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data") + return + } + za0002, bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Data", za0001) + return + } + z.Data[za0001] = za0002 + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return } // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message @@ -233,17 +237,18 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { // - int: The estimated size in bytes. // // Usage: -// size := d.Msgsize() +// +// size := d.Msgsize() func (z *data) Msgsize() (s int) { - s = 1 + 5 + msgp.MapHeaderSize - if z.Data != nil { - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - continue - } - s += msgp.StringPrefixSize + len(keyStr) + msgp.GuessSize(za0002) - } - } - return -} \ No newline at end of file + s = 1 + 5 + msgp.MapHeaderSize + if z.Data != nil { + for za0001, za0002 := range z.Data { + keyStr, ok := za0001.(string) + if !ok { + continue + } + s += msgp.StringPrefixSize + len(keyStr) + msgp.GuessSize(za0002) + } + } + return +} diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index a4fb6128bf..4d38800451 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -195,7 +195,7 @@ func FromContext(c fiber.Ctx) *Middleware { // Usage: // // m.Set("key", "value") -func (m *Middleware) Set(key any, value any) { +func (m *Middleware) Set(key, value any) { m.mu.Lock() defer m.mu.Unlock() diff --git a/middleware/session/session.go b/middleware/session/session.go index 286602c233..1f443420a9 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -151,7 +151,7 @@ func (s *Session) Get(key any) any { // Usage: // // s.Set("key", "value") -func (s *Session) Set(key any, val any) { +func (s *Session) Set(key, val any) { if s.data == nil { return } @@ -447,8 +447,6 @@ func (s *Session) decodeSessionData(rawData []byte) error { // // expiration := s.expiration() func (s *Session) expiration() time.Time { - s.mu.RLock() - defer s.mu.RUnlock() expiration, ok := s.Get(expirationKey).(time.Time) if ok { return expiration @@ -476,7 +474,5 @@ func (s *Session) isExpired() bool { // // s.setExpiration(time.Now().Add(time.Hour)) func (s *Session) setExpiration(expiration time.Time) { - s.mu.Lock() - defer s.mu.Unlock() s.Set(expirationKey, expiration) } diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index b503cb9c99..9cf0976798 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -357,8 +357,8 @@ func Test_Session_KeyTypes(t *testing.T) { // verify unregistered keys types are not allowed var ( - unregisteredKey unregisteredKeyType = 0 - unregisteredValue unregisteredValueType = 1 + unregisteredKey unregisteredKeyType + unregisteredValue unregisteredValueType ) sess.Set(unregisteredKey, "test") err = sess.Save() @@ -373,7 +373,7 @@ func Test_Session_KeyTypes(t *testing.T) { err = sess.Save() require.NoError(t, err) - sess.Reset() + require.NoError(t, sess.Reset()) var ( kbool = true diff --git a/middleware/session/store.go b/middleware/session/store.go index f550e81b46..df11b09882 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -154,11 +154,14 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { // Decode session data if found if rawData != nil { sess.data.Lock() - defer sess.data.Unlock() - if err := sess.decodeSessionData(rawData); err != nil { + err := sess.decodeSessionData(rawData) + sess.data.Unlock() + if err != nil { + sess.mu.Unlock() return nil, fmt.Errorf("failed to decode session data: %w", err) } } + sess.mu.Unlock() if fresh && s.AbsoluteTimeout > 0 { @@ -285,7 +288,6 @@ func (s *Store) GetSessionByID(id string) (*Session, error) { sess := acquireSession() sess.mu.Lock() - defer sess.mu.Unlock() sess.id = id sess.config = s @@ -294,8 +296,10 @@ func (s *Store) GetSessionByID(id string) (*Session, error) { decodeErr := sess.decodeSessionData(rawData) sess.data.Unlock() if decodeErr != nil { + sess.mu.Unlock() return nil, fmt.Errorf("failed to decode session data: %w", err) } + sess.mu.Unlock() if s.AbsoluteTimeout > 0 { if sess.isExpired() { From bc95c6aa6659b43ecea8ae7c392df05abeb05121 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 14:53:14 -0300 Subject: [PATCH 58/79] chore(middleware/session): Regenerate code in data_msgp.go --- middleware/session/data_msgp.go | 182 ++------------------------------ 1 file changed, 10 insertions(+), 172 deletions(-) diff --git a/middleware/session/data_msgp.go b/middleware/session/data_msgp.go index 483eea43ce..a640e141b8 100644 --- a/middleware/session/data_msgp.go +++ b/middleware/session/data_msgp.go @@ -7,18 +7,6 @@ import ( ) // DecodeMsg implements msgp.Decodable -// -// This method decodes the session data from the provided msgp.Reader. -// -// Parameters: -// - dc: The msgp.Reader to decode from. -// -// Returns: -// - error: An error if the decoding fails. -// -// Usage: -// -// err := d.DecodeMsg(reader) func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte _ = field @@ -36,36 +24,6 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { return } switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[any]any, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - zb0002-- - var za0001 any - var za0002 any - za0001, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, err = dc.ReadIntf() - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } default: err = dc.Skip() if err != nil { @@ -78,98 +36,26 @@ func (z *data) DecodeMsg(dc *msgp.Reader) (err error) { } // EncodeMsg implements msgp.Encodable -// -// This method encodes the session data to the provided msgp.Writer. -// -// Parameters: -// - en: The msgp.Writer to encode to. -// -// Returns: -// - error: An error if the encoding fails. -// -// Usage: -// -// err := d.EncodeMsg(writer) -func (z *data) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 1 - // write "Data" - err = en.Append(0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - if err != nil { - return - } - err = en.WriteMapHeader(uint32(len(z.Data))) +func (z data) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 0 + _ = z + err = en.Append(0x80) if err != nil { - err = msgp.WrapError(err, "Data") return } - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - return msgp.WrapError(err, "Data", za0001) - } - err = en.WriteString(keyStr) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - err = en.WriteIntf(za0002) - if err != nil { - err = msgp.WrapError(err, "Data", keyStr) - return - } - } return } // MarshalMsg implements msgp.Marshaler -// -// This method marshals the session data into a byte slice. -// -// Parameters: -// - b: The byte slice to marshal into. -// -// Returns: -// - []byte: The marshaled byte slice. -// - error: An error if the marshaling fails. -// -// Usage: -// -// b, err := d.MarshalMsg(nil) -func (z *data) MarshalMsg(b []byte) (o []byte, err error) { +func (z data) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 1 - // string "Data" - o = append(o, 0x81, 0xa4, 0x44, 0x61, 0x74, 0x61) - o = msgp.AppendMapHeader(o, uint32(len(z.Data))) - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - return nil, msgp.WrapError(err, "Data", za0001) - } - o = msgp.AppendString(o, keyStr) - o, err = msgp.AppendIntf(o, za0002) - if err != nil { - err = msgp.WrapError(err, "Data", keyStr) - return - } - } + // map header, size 0 + _ = z + o = append(o, 0x80) return } // UnmarshalMsg implements msgp.Unmarshaler -// -// This method unmarshals the session data from a byte slice. -// -// Parameters: -// - bts: The byte slice to unmarshal from. -// -// Returns: -// - []byte: The remaining byte slice after unmarshaling. -// - error: An error if the unmarshaling fails. -// -// Usage: -// -// b, err := d.UnmarshalMsg(bts) func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field @@ -187,36 +73,6 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { return } switch msgp.UnsafeString(field) { - case "Data": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - if z.Data == nil { - z.Data = make(map[any]any, zb0002) - } else if len(z.Data) > 0 { - for key := range z.Data { - delete(z.Data, key) - } - } - for zb0002 > 0 { - var za0001 any - var za0002 any - zb0002-- - za0001, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data") - return - } - za0002, bts, err = msgp.ReadIntfBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Data", za0001) - return - } - z.Data[za0001] = za0002 - } default: bts, err = msgp.Skip(bts) if err != nil { @@ -230,25 +86,7 @@ func (z *data) UnmarshalMsg(bts []byte) (o []byte, err error) { } // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -// -// This method returns the estimated size of the serialized session data. -// -// Returns: -// - int: The estimated size in bytes. -// -// Usage: -// -// size := d.Msgsize() -func (z *data) Msgsize() (s int) { - s = 1 + 5 + msgp.MapHeaderSize - if z.Data != nil { - for za0001, za0002 := range z.Data { - keyStr, ok := za0001.(string) - if !ok { - continue - } - s += msgp.StringPrefixSize + len(keyStr) + msgp.GuessSize(za0002) - } - } +func (z data) Msgsize() (s int) { + s = 1 return } From 6bba849c8191d4fbe33a48c74ecedd1ffb9046f6 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 15:14:41 -0300 Subject: [PATCH 59/79] refactor(middleware/session): rename GetSessionByID to GetByID This commit also includes changes to the session_test.go and store_test.go files to add test cases for the new GetByID method. --- docs/middleware/session.md | 2 +- middleware/session/session_test.go | 67 +++++++++++++++++++++++++++++- middleware/session/store.go | 17 +++++--- middleware/session/store_test.go | 8 ++-- 4 files changed, 83 insertions(+), 11 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 46e6a04775..628f21daa0 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -219,9 +219,9 @@ func (s *Session) SetIdleTimeout(idleTimeout time.Duration) ```go func (*Store) RegisterType(i any) func (s *Store) Get(c fiber.Ctx) (*Session, error) +func (s *Store) GetByID(id string) (*Session, error) func (s *Store) Reset() error func (s *Store) Delete(id string) error -func (s *Store) GetSessionByID(id string) (*Session, error) ``` ## Examples diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 9cf0976798..9840fe92c6 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -554,7 +554,7 @@ func Test_Session_Save(t *testing.T) { }) } -func Test_Session_Save_Expiration(t *testing.T) { +func Test_Session_Save_IdleTimeout(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { @@ -614,6 +614,71 @@ func Test_Session_Save_Expiration(t *testing.T) { }) } +func Test_Session_Save_Absolute(t *testing.T) { + t.Parallel() + + t.Run("save to cookie", func(t *testing.T) { + t.Parallel() + + const absoluteTimeout = 5 * time.Second + // session store + store := NewStore(Config{ + AbsoluteTimeout: absoluteTimeout, + }) + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get session + sess, err := store.Get(ctx) + require.NoError(t, err) + + // set value + sess.Set("name", "john") + + token := sess.ID() + + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + + // here you need to get the old session yet + ctx.Request().Header.SetCookie(store.sessionName, token) + sess, err = store.Get(ctx) + require.NoError(t, err) + require.Equal(t, "john", sess.Get("name")) + + // just to make sure the session has been expired + time.Sleep(absoluteTimeout + (10 * time.Millisecond)) + + sess.Release() + + app.ReleaseCtx(ctx) + ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // here you should get a new session + ctx.Request().Header.SetCookie(store.sessionName, token) + sess, err = store.Get(ctx) + defer sess.Release() + require.NoError(t, err) + require.Nil(t, sess.Get("name")) + require.NotEqual(t, sess.ID(), token) + + // try to get expired session by id + sess, err = store.GetByID(token) + require.Error(t, err) + require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) + require.Nil(t, sess) + }) +} + // go test -run Test_Session_Destroy func Test_Session_Destroy(t *testing.T) { t.Parallel() diff --git a/middleware/session/store.go b/middleware/session/store.go index df11b09882..d11f33f7e2 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -50,9 +50,16 @@ func NewStore(config ...Config) *Store { cfg.Storage = memory.New() } - return &Store{ + store := &Store{ Config: cfg, } + + if cfg.AbsoluteTimeout > 0 { + store.RegisterType(expirationKey) + store.RegisterType(time.Time{}) + } + + return store } // RegisterType registers a custom type for encoding/decoding into any storage provider. @@ -245,7 +252,7 @@ func (s *Store) Delete(id string) error { return s.Storage.Delete(id) } -// GetSessionByID retrieves a session by its ID from the storage. +// GetByID retrieves a session by its ID from the storage. // If the session is not found, it returns nil and an error. // // Note: @@ -256,7 +263,7 @@ func (s *Store) Delete(id string) error { // - Be aware of possible collisions if you are also using the session in a middleware. // // Usage: -// - If you modify a session returned by GetSession, you must call session.Save() to persist the changes. +// - If you modify a session returned by GetByID, you must call session.Save() to persist the changes. // - When you are done with the session, you should call session.Release() to release the session back to the pool. // // Parameters: @@ -268,11 +275,11 @@ func (s *Store) Delete(id string) error { // // Usage: // -// sess, err := store.GetSessionByID(id) +// sess, err := store.GetByID(id) // if err != nil { // // handle error // } -func (s *Store) GetSessionByID(id string) (*Session, error) { +func (s *Store) GetByID(id string) (*Session, error) { if id == "" { return nil, ErrEmptySessionID } diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index f57cc38fc2..3d2395e2fb 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -155,14 +155,14 @@ func TestStore_Delete(t *testing.T) { }) } -func Test_Store_GetSessionByID(t *testing.T) { +func Test_Store_GetByID(t *testing.T) { t.Parallel() // Create a new store store := NewStore() t.Run("empty session ID", func(t *testing.T) { t.Parallel() - sess, err := store.GetSessionByID("") + sess, err := store.GetByID("") require.Error(t, err) require.Nil(t, sess) require.Equal(t, ErrEmptySessionID, err) @@ -170,7 +170,7 @@ func Test_Store_GetSessionByID(t *testing.T) { t.Run("non-existent session ID", func(t *testing.T) { t.Parallel() - sess, err := store.GetSessionByID("non-existent-session-id") + sess, err := store.GetByID("non-existent-session-id") require.Error(t, err) require.Nil(t, sess) require.Equal(t, ErrSessionIDNotFoundInStore, err) @@ -191,7 +191,7 @@ func Test_Store_GetSessionByID(t *testing.T) { require.NoError(t, err) // Retrieve the session by ID - retrievedSession, err := store.GetSessionByID(sessionID) + retrievedSession, err := store.GetByID(sessionID) require.NoError(t, err) require.NotNil(t, retrievedSession) require.Equal(t, sessionID, retrievedSession.ID()) From 281c0e17b41981c016a7bd836bedddc15a4345ed Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 15:28:01 -0300 Subject: [PATCH 60/79] docs(middleware/session): AbsoluteTimeout --- docs/middleware/session.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 628f21daa0..445db0e8a4 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -136,7 +136,7 @@ type Config struct { CookiePath string CookieSameSite string IdleTimeout time.Duration - Expiration time.Duration + AbsoluteTimeout time.Duration CookieSecure bool CookieHTTPOnly bool CookieSessionOnly bool From 3d88eceefcd096f99cd040870e7b4fed631427dd Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 17:52:47 -0300 Subject: [PATCH 61/79] refactor(middleware/csrf): Rename Expiration to IdleTimeout --- docs/middleware/csrf.md | 15 ++++++--------- middleware/csrf/config.go | 21 ++++++--------------- middleware/csrf/csrf.go | 11 ++++------- middleware/csrf/csrf_test.go | 6 +++--- middleware/csrf/session_manager.go | 29 ++++++++++++++++++----------- middleware/session/config.go | 4 ++++ 6 files changed, 41 insertions(+), 45 deletions(-) diff --git a/docs/middleware/csrf.md b/docs/middleware/csrf.md index a034f9dfd7..8127432438 100644 --- a/docs/middleware/csrf.md +++ b/docs/middleware/csrf.md @@ -34,7 +34,7 @@ app.Use(csrf.New(csrf.Config{ KeyLookup: "header:X-Csrf-Token", CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, Extractor: func(c fiber.Ctx) (string, error) { ... }, })) @@ -106,15 +106,14 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error | CookieSecure | `bool` | Indicates if the CSRF cookie is secure. | false | | CookieHTTPOnly | `bool` | Indicates if the CSRF cookie is HTTP-only. | false | | CookieSameSite | `string` | Value of SameSite cookie. | "Lax" | -| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. Ignores Expiration if set to true. | false | -| Expiration | `time.Duration` | Expiration is the duration before the CSRF token will expire. | 1 * time.Hour | +| CookieSessionOnly | `bool` | Decides whether the cookie should last for only the browser session. (cookie expires on close). | false | +| IdleTimeout | `time.Duration` | IdleTimeout is the duration of inactivity before the CSRF token will expire. | 30 * time.Minute | | KeyGenerator | `func() string` | KeyGenerator creates a new CSRF token. | utils.UUID | | ErrorHandler | `fiber.ErrorHandler` | ErrorHandler is executed when an error is returned from fiber.Handler. | DefaultErrorHandler | | Extractor | `func(fiber.Ctx) (string, error)` | Extractor returns the CSRF token. If set, this will be used in place of an Extractor based on KeyLookup. | Extractor based on KeyLookup | | SingleUseToken | `bool` | SingleUseToken indicates if the CSRF token be destroyed and a new one generated on each use. (See TokenLifecycle) | false | | Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` | | Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` | -| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" | | TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` | ### Default Config @@ -124,11 +123,10 @@ var ConfigDefault = Config{ KeyLookup: "header:" + HeaderName, CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), - SessionKey: "csrfToken", } ``` @@ -144,12 +142,11 @@ var ConfigDefault = Config{ CookieSecure: true, CookieSessionOnly: true, CookieHTTPOnly: true, - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), Session: session.Store, - SessionKey: "csrfToken", } ``` @@ -304,7 +301,7 @@ The Referer header is automatically included in requests by all modern browsers, ## Token Lifecycle -Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time. +Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 30 minutes, and each subsequent request extends the expiration by the idle timeout. The token only expires if the user doesn't make a request for the duration of the idle timeout. ### Token Reuse diff --git a/middleware/csrf/config.go b/middleware/csrf/config.go index d37c33a58e..e718b15874 100644 --- a/middleware/csrf/config.go +++ b/middleware/csrf/config.go @@ -78,11 +78,6 @@ type Config struct { // Optional. Default value "Lax". CookieSameSite string - // SessionKey is the key used to store the token in the session - // - // Default: "csrfToken" - SessionKey string - // TrustedOrigins is a list of trusted origins for unsafe requests. // For requests that use the Origin header, the origin must match the // Host header or one of the TrustedOrigins. @@ -96,10 +91,10 @@ type Config struct { // Optional. Default: [] TrustedOrigins []string - // Expiration is the duration before csrf token will expire + // IdleTimeout is the duration of time the CSRF token is valid. // - // Optional. Default: 1 * time.Hour - Expiration time.Duration + // Optional. Default: 30 * time.Minute + IdleTimeout time.Duration // Indicates if CSRF cookie is secure. // Optional. Default value false. @@ -127,11 +122,10 @@ var ConfigDefault = Config{ KeyLookup: "header:" + HeaderName, CookieName: "csrf_", CookieSameSite: "Lax", - Expiration: 1 * time.Hour, + IdleTimeout: 30 * time.Minute, KeyGenerator: utils.UUIDv4, ErrorHandler: defaultErrorHandler, Extractor: FromHeader(HeaderName), - SessionKey: "csrfToken", } // default ErrorHandler that process return error from fiber.Handler @@ -153,8 +147,8 @@ func configDefault(config ...Config) Config { if cfg.KeyLookup == "" { cfg.KeyLookup = ConfigDefault.KeyLookup } - if int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = ConfigDefault.IdleTimeout } if cfg.CookieName == "" { cfg.CookieName = ConfigDefault.CookieName @@ -168,9 +162,6 @@ func configDefault(config ...Config) Config { if cfg.ErrorHandler == nil { cfg.ErrorHandler = ConfigDefault.ErrorHandler } - if cfg.SessionKey == "" { - cfg.SessionKey = ConfigDefault.SessionKey - } // Generate the correct extractor to get the token from the correct location selectors := strings.Split(cfg.KeyLookup, ":") diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index d417730416..dedfe6bd55 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -49,10 +49,7 @@ func New(config ...Config) fiber.Handler { var sessionManager *sessionManager var storageManager *storageManager if cfg.Session != nil { - // Register the Token struct in the session store - cfg.Session.RegisterType(Token{}) - - sessionManager = newSessionManager(cfg.Session, cfg.SessionKey) + sessionManager = newSessionManager(cfg.Session) } else { storageManager = newStorageManager(cfg.Storage) } @@ -220,9 +217,9 @@ func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *se // createOrExtendTokenInStorage creates or extends the token in the storage func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) { if cfg.Session != nil { - sessionManager.setRaw(c, token, dummyValue, cfg.Expiration) + sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout) } else { - storageManager.setRaw(token, dummyValue, cfg.Expiration) + storageManager.setRaw(token, dummyValue, cfg.IdleTimeout) } } @@ -237,7 +234,7 @@ func deleteTokenFromStorage(c fiber.Ctx, token string, cfg Config, sessionManage // Update CSRF cookie // if expireCookie is true, the cookie will expire immediately func updateCSRFCookie(c fiber.Ctx, cfg Config, token string) { - setCSRFCookie(c, cfg, token, cfg.Expiration) + setCSRFCookie(c, cfg, token, cfg.IdleTimeout) } func expireCSRFCookie(c fiber.Ctx, cfg Config) { diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 0e486d4830..090082f4d8 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -217,7 +217,7 @@ func Test_CSRF_ExpiredToken(t *testing.T) { app := fiber.New() app.Use(New(Config{ - Expiration: 1 * time.Second, + IdleTimeout: 1 * time.Second, })) app.Post("/", func(c fiber.Ctx) error { @@ -284,8 +284,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { // middleware config config := Config{ - Session: store, - Expiration: 1 * time.Second, + Session: store, + IdleTimeout: 1 * time.Second, } // middleware diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 143748d6cd..8961c6a542 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -10,17 +10,24 @@ import ( type sessionManager struct { session *session.Store - key string } -func newSessionManager(s *session.Store, k string) *sessionManager { +type sessionKeyType int + +const ( + sessionKey sessionKeyType = 0 +) + +func newSessionManager(s *session.Store) *sessionManager { // Create new storage handler - sessionManager := &sessionManager{ - key: k, - } + sessionManager := new(sessionManager) if s != nil { // Use provided storage if provided sessionManager.session = s + + // Register the sessionKeyType and Token type + s.RegisterType(sessionKeyType(0)) + s.RegisterType(Token{}) } return sessionManager } @@ -32,7 +39,7 @@ func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { var ok bool if sess != nil { - token, ok = sess.Get(m.key).(Token) + token, ok = sess.Get(sessionKey).(Token) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) @@ -40,7 +47,7 @@ func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { // Handle error return nil } - token, ok = storeSess.Get(m.key).(Token) + token, ok = storeSess.Get(sessionKey).(Token) } if ok { @@ -58,7 +65,7 @@ func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Du sess := session.FromContext(c) if sess != nil { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here - sess.Set(m.key, &Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) + sess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) @@ -66,7 +73,7 @@ func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Du // Handle error return } - storeSess.Set(m.key, &Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) + storeSess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) if err := storeSess.Save(); err != nil { log.Warn("csrf: failed to save session: ", err) } @@ -77,7 +84,7 @@ func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Du func (m *sessionManager) delRaw(c fiber.Ctx) { sess := session.FromContext(c) if sess != nil { - sess.Delete(m.key) + sess.Delete(sessionKey) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) @@ -85,7 +92,7 @@ func (m *sessionManager) delRaw(c fiber.Ctx) { // Handle error return } - storeSess.Delete(m.key) + storeSess.Delete(sessionKey) if err := storeSess.Save(); err != nil { log.Warn("csrf: failed to save session: ", err) } diff --git a/middleware/session/config.go b/middleware/session/config.go index 992173b4db..2686158e44 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -157,6 +157,10 @@ func configDefault(config ...Config) Config { if cfg.IdleTimeout <= 0 { cfg.IdleTimeout = ConfigDefault.IdleTimeout } + // Ensure AbsoluteTimeout is greater than or equal to IdleTimeout. + if cfg.AbsoluteTimeout > 0 && cfg.AbsoluteTimeout < cfg.IdleTimeout { + panic("[session] AbsoluteTimeout must be greater than or equal to IdleTimeout") + } if cfg.KeyLookup == "" { cfg.KeyLookup = ConfigDefault.KeyLookup } From 3ddfeaee4d71ea0aa6c0f5057e54a616037a5e49 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 18:01:19 -0300 Subject: [PATCH 62/79] docs(whats-new): CSRF Rename Expiration to IdleTimeout and remove SessionKey field --- docs/whats_new.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/whats_new.md b/docs/whats_new.md index 577223b9a0..92c5f2734e 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -30,6 +30,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [🧰 Generic functions](#-generic-functions) - [🧬 Middlewares](#-middlewares) - [CORS](#cors) + - [CSRF](#csrf) - [Session](#session) - [Filesystem](#filesystem) - [Monitor](#monitor) @@ -504,6 +505,24 @@ app.Use(cors.New(cors.Config{ })) ``` +#### CSRF + +- **Field Renaming**: The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes. Update your code as follows: + +```go +// Before +app.Use(csrf.New(csrf.Config{ + Expiration: 10 * time.Minute, +})) + +// After +app.Use(csrf.New(csrf.Config{ + IdleTimeout: 10 * time.Minute, +})) +``` + +- **Session Key Removal**: The `SessionKey` field has been removed from the CSRF middleware configuration. The session key is now an unexported constant within the middleware to avoid potential key collisions in the session store. + #### Filesystem You need to move filesystem middleware to static middleware due to it has been removed from the core. From c3d3f0c3255c1d1b6e825b8f9ea9faedd11e3e14 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 18:20:31 -0300 Subject: [PATCH 63/79] refactor(middleware/session): Rename expirationKeyType to absExpirationKeyType and update related functions --- middleware/session/session.go | 33 +++++++++++++++--------------- middleware/session/session_test.go | 1 + middleware/session/store.go | 10 ++++----- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index 1f443420a9..35d05ec306 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -24,11 +24,11 @@ type Session struct { fresh bool // if new session } -type expirationKeyType int +type absExpirationKeyType int const ( // sessionIDContextKey is the key used to store the session ID in the context locals. - expirationKey expirationKeyType = iota + absExpirationKey absExpirationKeyType = iota ) var sessionPool = sync.Pool{ @@ -438,34 +438,35 @@ func (s *Session) decodeSessionData(rawData []byte) error { return nil } -// expiration returns the session expiration time or a zero time if not set. +// absExpiration returns the session absolute expiration time or a zero time if not set. // // Returns: -// - time.Time: The session expiration time, or a zero time if not set. +// - time.Time: The session absolute expiration time. Zero time if not set. // // Usage: // -// expiration := s.expiration() -func (s *Session) expiration() time.Time { - expiration, ok := s.Get(expirationKey).(time.Time) +// expiration := s.absExpiration() +func (s *Session) absExpiration() time.Time { + absExpiration, ok := s.Get(absExpirationKey).(time.Time) if ok { - return expiration + return absExpiration } return time.Time{} } -// isExpired returns true if the session is expired. +// isAbsExpired returns true if the session is expired. // -// If the session expiration time is zero, the session is considered to never expire. +// If the session has an absolute expiration time set, this function will return true if the +// current time is after the absolute expiration time. // // Returns: // - bool: True if the session is expired, otherwise false. -func (s *Session) isExpired() bool { - expiration := s.expiration() - return !expiration.IsZero() && time.Now().After(expiration) +func (s *Session) isAbsExpired() bool { + absExpiration := s.absExpiration() + return !absExpiration.IsZero() && time.Now().After(absExpiration) } -// setExpiration sets the session expiration time. +// setAbsoluteExpiration sets the absolute session expiration time. // // Parameters: // - expiration: The session expiration time. @@ -473,6 +474,6 @@ func (s *Session) isExpired() bool { // Usage: // // s.setExpiration(time.Now().Add(time.Hour)) -func (s *Session) setExpiration(expiration time.Time) { - s.Set(expirationKey, expiration) +func (s *Session) setAbsExpiration(absExpiration time.Time) { + s.Set(absExpirationKey, absExpiration) } diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 9840fe92c6..ce8aaf7192 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -623,6 +623,7 @@ func Test_Session_Save_Absolute(t *testing.T) { const absoluteTimeout = 5 * time.Second // session store store := NewStore(Config{ + IdleTimeout: 5 * time.Second, AbsoluteTimeout: absoluteTimeout, }) // fiber instance diff --git a/middleware/session/store.go b/middleware/session/store.go index d11f33f7e2..4cfcc998bd 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -14,7 +14,7 @@ import ( // ErrEmptySessionID is an error that occurs when the session ID is empty. var ( - ErrEmptySessionID = errors.New("session id cannot be empty") + ErrEmptySessionID = errors.New("session ID cannot be empty") ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) @@ -55,7 +55,7 @@ func NewStore(config ...Config) *Store { } if cfg.AbsoluteTimeout > 0 { - store.RegisterType(expirationKey) + store.RegisterType(absExpirationKey) store.RegisterType(time.Time{}) } @@ -172,8 +172,8 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { sess.mu.Unlock() if fresh && s.AbsoluteTimeout > 0 { - sess.setExpiration(time.Now().Add(s.AbsoluteTimeout)) - } else if sess.isExpired() { + sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) + } else if sess.isAbsExpired() { if err := sess.Reset(); err != nil { return nil, fmt.Errorf("failed to reset session: %w", err) } @@ -309,7 +309,7 @@ func (s *Store) GetByID(id string) (*Session, error) { sess.mu.Unlock() if s.AbsoluteTimeout > 0 { - if sess.isExpired() { + if sess.isAbsExpired() { if err := sess.Destroy(); err != nil { log.Errorf("failed to destroy expired session: %v", err) } From 0e9a73e6ebba0a0faf18f833486ad1ab9a2c1b99 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Thu, 26 Sep 2024 18:29:57 -0300 Subject: [PATCH 64/79] refactor(middleware/session): rename Test_Session_Save_Absolute to Test_Session_Save_AbsoluteTimeout --- middleware/session/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index ce8aaf7192..f604d3b2dd 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -614,7 +614,7 @@ func Test_Session_Save_IdleTimeout(t *testing.T) { }) } -func Test_Session_Save_Absolute(t *testing.T) { +func Test_Session_Save_AbsoluteTimeout(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { From a4672364ca23a9ce151214b2923e4a890a41b2c5 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 1 Oct 2024 11:33:28 -0300 Subject: [PATCH 65/79] chore(middleware/session): update as per PR comments --- docs/middleware/session.md | 4 ++++ docs/whats_new.md | 4 ++-- middleware/session/data_test.go | 2 +- middleware/session/middleware.go | 3 --- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 445db0e8a4..e0ec04f968 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -41,6 +41,10 @@ As of v3, we recommend using the middleware handler for session management. Howe - **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with `IdleTimeout` and `AbsoluteTimeout` fields, which explicitly defines the session's idle and absolute timeout periods. + - **Idle Timeout**: The new `IdleTimeout`, handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. + + - **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. + For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). ### Migrating v2 to v3 Example (Legacy Approach) diff --git a/docs/whats_new.md b/docs/whats_new.md index 92c5f2734e..4779a57364 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -323,9 +323,9 @@ The Session middleware has undergone key changes in v3 to improve functionality - **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. -- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which strictly handles session inactivity. If you require a maximum session duration, you'll need to implement it within your own session data. +- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. -- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. +- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). diff --git a/middleware/session/data_test.go b/middleware/session/data_test.go index 7b8b0787b5..1913f761d3 100644 --- a/middleware/session/data_test.go +++ b/middleware/session/data_test.go @@ -121,7 +121,7 @@ func TestData_Len(t *testing.T) { d.Set("key2", "value2") d.Set("key3", "value3") - done := make(chan bool) + done := make(chan bool, 2) // Buffered channel with size 2 go func() { length := d.Len() assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 4d38800451..43ed7a3501 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/log" ) // Middleware holds session data and configuration. @@ -179,8 +178,6 @@ func releaseMiddleware(m *Middleware) { func FromContext(c fiber.Ctx) *Middleware { m, ok := c.Locals(middlewareContextKey).(*Middleware) if !ok { - // TODO: since this may be called we may not want to log this except in debug mode? - log.Warn("session: Session middleware not registered. See https://docs.gofiber.io/middleware/session") return nil } return m From 6f35ff847153dc5651758b5a59c1970e2db43017 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 1 Oct 2024 11:50:06 -0300 Subject: [PATCH 66/79] docs(middlware/session): fix indent lint --- docs/middleware/session.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index e0ec04f968..26d350d2f0 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -41,9 +41,9 @@ As of v3, we recommend using the middleware handler for session management. Howe - **Expiration Handling**: Previously, the `Expiration` field represented the maximum session duration before expiration. However, it would extend every time the session was saved, making its behavior a mix between session duration and session idle timeout. The `Expiration` field has been removed and replaced with `IdleTimeout` and `AbsoluteTimeout` fields, which explicitly defines the session's idle and absolute timeout periods. - - **Idle Timeout**: The new `IdleTimeout`, handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. + - **Idle Timeout**: The new `IdleTimeout`, handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. - - **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. + - **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. For more details about Fiber v3, see [What’s New](https://github.com/gofiber/fiber/blob/main/docs/whats_new.md). From f3c4e8ed989612c54d5f86a57a3504b48a024b8c Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Tue, 1 Oct 2024 12:44:15 -0300 Subject: [PATCH 67/79] fix(middleware/session): Address EfeCtn Comments --- docs/middleware/session.md | 14 +++++++------- middleware/session/config.go | 6 +++--- middleware/session/config_test.go | 2 +- middleware/session/middleware.go | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 26d350d2f0..9fe685a45c 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -93,7 +93,7 @@ app.Get("/", func(c *fiber.Ctx) error { ```go store := session.NewStore() -app.Get("/", func(c *fiber.Ctx) error { +app.Get("/", func(c fiber.Ctx) error { sess, err := store.Get(c) if err != nil { return err @@ -131,9 +131,9 @@ Defines the configuration options for the session middleware. ```go type Config struct { Storage fiber.Storage - Next func(c *fiber.Ctx) bool + Next func(fiber.Ctx) bool Store *Store - ErrorHandler func(*fiber.Ctx, error) + ErrorHandler func(fiber.Ctx, error) KeyGenerator func() string KeyLookup string CookieDomain string @@ -188,7 +188,7 @@ func FromContext(c fiber.Ctx) *Middleware ### Config Methods ```go -func DefaultErrorHandler(c *fiber.Ctx, err error) +func DefaultErrorHandler(fiber.Ctx, err error) ``` ### Middleware Methods @@ -259,7 +259,7 @@ func main() { Store: sessionStore, })) - app.Get("/", func(c *fiber.Ctx) error { + app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) if sess == nil { return c.SendStatus(fiber.StatusInternalServerError) @@ -326,7 +326,7 @@ func main() { Store: sessionStore, })) - app.Get("/", func(c *fiber.Ctx) error { + app.Get("/", func(c fiber.Ctx) error { sess, err := sessionStore.Get(c) if err != nil { return c.SendStatus(fiber.StatusInternalServerError) @@ -341,7 +341,7 @@ func main() { return c.SendString("Welcome " + name) }) - app.Post("/login", func(c *fiber.Ctx) error { + app.Post("/login", func(c fiber.Ctx) error { sess, err := sessionStore.Get(c) if err != nil { return c.SendStatus(fiber.StatusInternalServerError) diff --git a/middleware/session/config.go b/middleware/session/config.go index 2686158e44..21513d79cb 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -28,7 +28,7 @@ type Config struct { // ErrorHandler defines a function to handle errors. // // Optional. Default: nil - ErrorHandler func(*fiber.Ctx, error) + ErrorHandler func(fiber.Ctx, error) // KeyGenerator generates the session key. // @@ -123,10 +123,10 @@ var ConfigDefault = Config{ // Usage: // // DefaultErrorHandler(c, err) -func DefaultErrorHandler(c *fiber.Ctx, err error) { +func DefaultErrorHandler(c fiber.Ctx, err error) { log.Errorf("session: %v", err) if c != nil { - if sendErr := (*c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { + if sendErr := (c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { log.Errorf("session: %v", sendErr) } } diff --git a/middleware/session/config_test.go b/middleware/session/config_test.go index 80d04f9750..c87ecef258 100644 --- a/middleware/session/config_test.go +++ b/middleware/session/config_test.go @@ -42,7 +42,7 @@ func TestDefaultErrorHandler(t *testing.T) { ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // Test DefaultErrorHandler - DefaultErrorHandler(&ctx, fiber.ErrInternalServerError) + DefaultErrorHandler(ctx, fiber.ErrInternalServerError) require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) } diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 43ed7a3501..c14bc19efe 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -12,7 +12,7 @@ import ( // Middleware holds session data and configuration. type Middleware struct { Session *Session - ctx *fiber.Ctx + ctx fiber.Ctx config Config mu sync.RWMutex destroyed bool @@ -119,7 +119,7 @@ func (m *Middleware) initialize(c fiber.Ctx, cfg Config) { m.config = cfg m.Session = session - m.ctx = &c + m.ctx = c c.Locals(middlewareContextKey, m) } From e41ee7460b0aaf0f5220fe1d82300eb08fa217c1 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 14:19:28 -0300 Subject: [PATCH 68/79] refactor(middleware/session): Move bytesBuffer to it's own pool --- middleware/session/session.go | 62 ++++++++++++++++++++---------- middleware/session/session_test.go | 29 ++++++++++++++ 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index 35d05ec306..e2e3bee5d9 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -17,7 +17,6 @@ type Session struct { ctx fiber.Ctx // fiber context config *Store // store configuration data *data // key value data - byteBuffer *bytes.Buffer // byte buffer for encoding/decoding id string // session id idleTimeout time.Duration // idleTimeout of this session mu sync.RWMutex // Mutex to protect non-data fields @@ -31,11 +30,16 @@ const ( absExpirationKey absExpirationKeyType = iota ) +// Session pool for reusing byte buffers. +var byteBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + var sessionPool = sync.Pool{ New: func() any { - return &Session{ - byteBuffer: new(bytes.Buffer), - } + return &Session{} }, } @@ -89,9 +93,6 @@ func releaseSession(s *Session) { if s.data != nil { s.data.Reset() } - if s.byteBuffer != nil { - s.byteBuffer.Reset() - } s.mu.Unlock() sessionPool.Put(s) } @@ -242,10 +243,6 @@ func (s *Session) Reset() error { s.mu.Lock() defer s.mu.Unlock() - // Reset byte buffer - if s.byteBuffer != nil { - s.byteBuffer.Reset() - } // Reset expiration s.idleTimeout = 0 @@ -310,18 +307,13 @@ func (s *Session) saveSession() error { s.setSession() // Encode session data - encCache := gob.NewEncoder(s.byteBuffer) s.data.RLock() - err := encCache.Encode(&s.data.Data) + encodedBytes, err := s.encodeSessionData() s.data.RUnlock() if err != nil { return fmt.Errorf("failed to encode data: %w", err) } - // Copy the data in buffer - encodedBytes := make([]byte, s.byteBuffer.Len()) - copy(encodedBytes, s.byteBuffer.Bytes()) - // Pass copied bytes with session id to provider return s.config.Storage.Set(s.id, encodedBytes, s.idleTimeout) } @@ -430,14 +422,44 @@ func (s *Session) delSession() { // // err := s.decodeSessionData(rawData) func (s *Session) decodeSessionData(rawData []byte) error { - _, _ = s.byteBuffer.Write(rawData) - encCache := gob.NewDecoder(s.byteBuffer) - if err := encCache.Decode(&s.data.Data); err != nil { + byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + defer byteBufferPool.Put(byteBuffer) + defer byteBuffer.Reset() + _, _ = byteBuffer.Write(rawData) + decCache := gob.NewDecoder(byteBuffer) + if err := decCache.Decode(&s.data.Data); err != nil { return fmt.Errorf("failed to decode session data: %w", err) } return nil } +// encodeSessionData encodes session data to raw bytes +// +// Parameters: +// - rawData: The raw byte data to encode. +// +// Returns: +// - error: An error if the encoding fails. +// +// Usage: +// +// err := s.encodeSessionData(rawData) +func (s *Session) encodeSessionData() ([]byte, error) { + byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + defer byteBufferPool.Put(byteBuffer) + defer byteBuffer.Reset() + encCache := gob.NewEncoder(byteBuffer) + if err := encCache.Encode(&s.data.Data); err != nil { + return nil, fmt.Errorf("failed to encode session data: %w", err) + } + // Copy the bytes + // Copy the data in buffer + encodedBytes := make([]byte, byteBuffer.Len()) + copy(encodedBytes, byteBuffer.Bytes()) + + return encodedBytes, nil +} + // absExpiration returns the session absolute expiration time or a zero time if not set. // // Returns: diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index f604d3b2dd..12a412a01b 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -1258,3 +1258,32 @@ func Test_Session_Concurrency(t *testing.T) { require.NoError(t, err) } } + +// func TestStore_Get_DecodeSessionDataError(t *testing.T) { +// // Initialize a new store with default config +// store := NewStore() + +// // Create a new Fiber app +// app := fiber.New() + +// // Generate a fake session ID +// sessionID := uuid.New().String() + +// // Store invalid session data to simulate decode error +// err := store.Storage.Set(sessionID, []byte("invalid data"), 0) +// require.NoError(t, err, "Failed to set invalid session data") + +// // Create a new request context +// c := app.AcquireCtx(&fasthttp.RequestCtx{}) +// defer app.ReleaseCtx(c) + +// // Set the session ID in cookies +// c.Request().Header.SetCookie(store.sessionName, sessionID) + +// // Attempt to get the session +// _, err = store.Get(c) +// require.Error(t, err, "Expected error due to invalid session data, but got nil") + +// // Check that the error message is as expected +// require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") +// } From 07092c83cbf46052c0ae0bd35f955238f57edf45 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 15:03:47 -0300 Subject: [PATCH 69/79] test(middleware/session): add decodeSessionData error coverage --- middleware/session/session_test.go | 43 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 12a412a01b..02beb96c5c 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -8,6 +8,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -1259,31 +1260,31 @@ func Test_Session_Concurrency(t *testing.T) { } } -// func TestStore_Get_DecodeSessionDataError(t *testing.T) { -// // Initialize a new store with default config -// store := NewStore() +func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) { + // Initialize a new store with default config + store := NewStore() -// // Create a new Fiber app -// app := fiber.New() + // Create a new Fiber app + app := fiber.New() -// // Generate a fake session ID -// sessionID := uuid.New().String() + // Generate a fake session ID + sessionID := uuid.New().String() -// // Store invalid session data to simulate decode error -// err := store.Storage.Set(sessionID, []byte("invalid data"), 0) -// require.NoError(t, err, "Failed to set invalid session data") + // Store invalid session data to simulate decode error + err := store.Storage.Set(sessionID, []byte("invalid data"), 0) + require.NoError(t, err, "Failed to set invalid session data") -// // Create a new request context -// c := app.AcquireCtx(&fasthttp.RequestCtx{}) -// defer app.ReleaseCtx(c) + // Create a new request context + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) -// // Set the session ID in cookies -// c.Request().Header.SetCookie(store.sessionName, sessionID) + // Set the session ID in cookies + c.Request().Header.SetCookie(store.sessionName, sessionID) -// // Attempt to get the session -// _, err = store.Get(c) -// require.Error(t, err, "Expected error due to invalid session data, but got nil") + // Attempt to get the session + _, err = store.Get(c) + require.Error(t, err, "Expected error due to invalid session data, but got nil") -// // Check that the error message is as expected -// require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") -// } + // Check that the error message is as expected + require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") +} From 84adbe1a678099ecdaa0768d99ba3bad1c07ad06 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 16:35:49 -0300 Subject: [PATCH 70/79] refactor(middleware/session): Update absolute timeout handling - Update absolute timeout handling in getSession function - Set absolute expiration time in getSession function - Delete expired session in GetByID function --- middleware/session/session_test.go | 39 ++++++++++++++++++++++++++---- middleware/session/store.go | 7 ++++-- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 02beb96c5c..03fc305285 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -621,12 +621,16 @@ func Test_Session_Save_AbsoluteTimeout(t *testing.T) { t.Run("save to cookie", func(t *testing.T) { t.Parallel() - const absoluteTimeout = 5 * time.Second + const absoluteTimeout = 1 * time.Second // session store store := NewStore(Config{ - IdleTimeout: 5 * time.Second, + IdleTimeout: absoluteTimeout, AbsoluteTimeout: absoluteTimeout, }) + + // force change to IdleTimeout + store.Config.IdleTimeout = 10 * time.Second + // fiber instance app := fiber.New() // fiber context @@ -657,21 +661,35 @@ func Test_Session_Save_AbsoluteTimeout(t *testing.T) { require.Equal(t, "john", sess.Get("name")) // just to make sure the session has been expired - time.Sleep(absoluteTimeout + (10 * time.Millisecond)) + time.Sleep(absoluteTimeout + (100 * time.Millisecond)) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // here you should get a new session ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) - defer sess.Release() require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) + require.True(t, sess.Fresh()) + require.IsType(t, time.Time{}, sess.Get(absExpirationKey)) + + token = sess.ID() + + sess.Set("name", "john") + + // save session + err = sess.Save() + require.NoError(t, err) + + sess.Release() + app.ReleaseCtx(ctx) + + // just to make sure the session has been expired + time.Sleep(absoluteTimeout + (100 * time.Millisecond)) // try to get expired session by id sess, err = store.GetByID(token) @@ -1287,4 +1305,15 @@ func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) { // Check that the error message is as expected require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") + + // Check that the error is as expected + require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") + + // Attempt to get the session by ID + _, err = store.GetByID(sessionID) + require.Error(t, err, "Expected error due to invalid session data, but got nil") + + // Check that the error message is as expected + require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") + } diff --git a/middleware/session/store.go b/middleware/session/store.go index 4cfcc998bd..888e39727f 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -177,6 +177,7 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { if err := sess.Reset(); err != nil { return nil, fmt.Errorf("failed to reset session: %w", err) } + sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } return sess, nil @@ -310,8 +311,10 @@ func (s *Store) GetByID(id string) (*Session, error) { if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { - if err := sess.Destroy(); err != nil { - log.Errorf("failed to destroy expired session: %v", err) + err := sess.config.Storage.Delete(sess.ID()) + sess.Release() + if err != nil { + log.Errorf("failed to delete expired session: %v", err) } return nil, ErrSessionIDNotFoundInStore } From f6440e25afa6d636a863a6692c169ec6b59842e0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:05:29 -0300 Subject: [PATCH 71/79] refactor(session/middleware): fix *Session nil ctx when using Store.GetByID --- middleware/session/session.go | 12 ++++++++++++ middleware/session/store.go | 12 +++++------- middleware/session/store_test.go | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index e2e3bee5d9..ffb5c52722 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -278,6 +278,10 @@ func (s *Session) refresh() { // // err := s.Save() func (s *Session) Save() error { + if s.ctx == nil { + return s.saveSession() + } + // If the session is being used in the handler, it should not be saved if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok { if m.Session == s { @@ -348,6 +352,10 @@ func (s *Session) SetIdleTimeout(idleTimeout time.Duration) { } func (s *Session) setSession() { + if s.ctx == nil { + return + } + if s.config.source == SourceHeader { s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id)) s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id)) @@ -380,6 +388,10 @@ func (s *Session) setSession() { } func (s *Session) delSession() { + if s.ctx == nil { + return + } + if s.config.source == SourceHeader { s.ctx.Request().Header.Del(s.config.sessionName) s.ctx.Response().Header.Del(s.config.sessionName) diff --git a/middleware/session/store.go b/middleware/session/store.go index 888e39727f..0311e79ea2 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -297,24 +297,22 @@ func (s *Store) GetByID(id string) (*Session, error) { sess.mu.Lock() - sess.id = id sess.config = s + sess.id = id + sess.fresh = false sess.data.Lock() decodeErr := sess.decodeSessionData(rawData) sess.data.Unlock() + sess.mu.Unlock() if decodeErr != nil { - sess.mu.Unlock() return nil, fmt.Errorf("failed to decode session data: %w", err) } - sess.mu.Unlock() if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { - err := sess.config.Storage.Delete(sess.ID()) - sess.Release() - if err != nil { - log.Errorf("failed to delete expired session: %v", err) + if err := sess.Destroy(); err != nil { + log.Errorf("failed to destroy session: %v", err) } return nil, ErrSessionIDNotFoundInStore } diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 3d2395e2fb..bff338c35e 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -181,6 +181,7 @@ func Test_Store_GetByID(t *testing.T) { // Create a new session ctx := fiber.New().AcquireCtx(&fasthttp.RequestCtx{}) session, err := store.Get(ctx) + defer session.Release() require.NoError(t, err) // Save the session ID @@ -192,8 +193,26 @@ func Test_Store_GetByID(t *testing.T) { // Retrieve the session by ID retrievedSession, err := store.GetByID(sessionID) + defer retrievedSession.Release() require.NoError(t, err) require.NotNil(t, retrievedSession) require.Equal(t, sessionID, retrievedSession.ID()) + + // Call Save on the retrieved session + retrievedSession.Set("key", "value") + err = retrievedSession.Save() + require.NoError(t, err) + + // Call Other Session methods + require.Equal(t, "value", retrievedSession.Get("key")) + require.False(t, retrievedSession.Fresh()) + + require.NoError(t, retrievedSession.Reset()) + require.NoError(t, retrievedSession.Destroy()) + require.IsType(t, []any{}, retrievedSession.Keys()) + require.NoError(t, retrievedSession.Regenerate()) + require.NotPanics(t, func() { + retrievedSession.Release() + }) }) } From eac16b6af638a3c564a23168e4b569afa8c9728f Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:07:59 -0300 Subject: [PATCH 72/79] refactor(middleware/session): Remove unnecessary line in session_test.go --- middleware/session/session_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 03fc305285..038bfc4b8d 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -1315,5 +1315,4 @@ func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) { // Check that the error message is as expected require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") - } From 7068a0e91e7c1fe527b4503f87833a37c20d2444 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:22:50 -0300 Subject: [PATCH 73/79] fix(middleware/session): *Session lifecycle issues --- middleware/session/store.go | 5 ++++- middleware/session/store_test.go | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/middleware/session/store.go b/middleware/session/store.go index 0311e79ea2..800e38cdb2 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -165,6 +165,7 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { sess.data.Unlock() if err != nil { sess.mu.Unlock() + sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", err) } } @@ -306,12 +307,14 @@ func (s *Store) GetByID(id string) (*Session, error) { sess.data.Unlock() sess.mu.Unlock() if decodeErr != nil { - return nil, fmt.Errorf("failed to decode session data: %w", err) + sess.Release() + return nil, fmt.Errorf("failed to decode session data: %w", decodeErr) } if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { if err := sess.Destroy(); err != nil { + sess.Release() log.Errorf("failed to destroy session: %v", err) } return nil, ErrSessionIDNotFoundInStore diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index bff338c35e..71a10dc941 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -193,7 +193,6 @@ func Test_Store_GetByID(t *testing.T) { // Retrieve the session by ID retrievedSession, err := store.GetByID(sessionID) - defer retrievedSession.Release() require.NoError(t, err) require.NotNil(t, retrievedSession) require.Equal(t, sessionID, retrievedSession.ID()) From 87a6cb90c0c4d4c1de051efbc9956e4be4530a35 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:33:43 -0300 Subject: [PATCH 74/79] docs(middleware/session): Update GetByID method documentation --- middleware/session/store.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/middleware/session/store.go b/middleware/session/store.go index 800e38cdb2..013743d068 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -257,15 +257,21 @@ func (s *Store) Delete(id string) error { // GetByID retrieves a session by its ID from the storage. // If the session is not found, it returns nil and an error. // -// Note: -// - Unlike session Middleware methods, Session methods do not automatically: -// - Load the session into the context -// - Save the session data to the storage and update the client cookie +// Unlike session middleware methods, this function does not automatically: // -// - Be aware of possible collisions if you are also using the session in a middleware. +// - Load the session into the request context. +// +// - Save the session data to the storage or update the client cookie. +// +// Important Notes: +// +// - The session object returned by GetByID does not have a context associated with it. +// +// - When using this method alongside session middleware, there is a potential for collisions, +// so be mindful of interactions between manually retrieved sessions and middleware-managed sessions. // -// Usage: // - If you modify a session returned by GetByID, you must call session.Save() to persist the changes. +// // - When you are done with the session, you should call session.Release() to release the session back to the pool. // // Parameters: From e5e5fd84d5a92cd3ee53979168d05790d2865c1e Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:46:09 -0300 Subject: [PATCH 75/79] docs(middleware/session): Update GetByID method documentation --- docs/middleware/session.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 9fe685a45c..808f9ee01e 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -228,6 +228,25 @@ func (s *Store) Reset() error func (s *Store) Delete(id string) error ``` +:::note +#### `GetByID` Method + +The `GetByID` method retrieves a session from storage using its session ID. Unlike `Get`, which ties the session to a `fiber.Ctx` (request-response cycle), `GetByID` operates independently of any HTTP context. This makes it ideal for scenarios such as background processing, scheduled tasks, or non-HTTP-related session management. + +##### Key Features: +- **Context Independence**: Sessions retrieved via `GetByID` are not bound to `fiber.Ctx`. This means the session can be manipulated in contexts that aren't tied to an active HTTP request-response cycle. +- **Background Task Suitability**: Use this method when you need to manage sessions outside of the standard HTTP workflow, such as in scheduled jobs, background tasks, or any non-HTTP context where session data needs to be accessed or modified. + +##### Usage Considerations: +- **Manual Persistence**: Since there is no associated `fiber.Ctx`, changes made to the session (e.g., modifying data) will **not** automatically be saved to storage. You **must** call `session.Save()` explicitly to persist any updates to storage. +- **No Automatic Cookie Handling**: Any updates made to the session will **not** affect the client-side cookies. If the session changes need to be reflected in the client (e.g., in a future HTTP response), you will need to handle this manually by setting the cookies via other methods. +- **Resource Management**: After using a session retrieved by `GetByID`, you should call `session.Release()` to properly release the session back to the pool and free up resources. + +##### Example Use Cases: +- **Scheduled Jobs**: Retrieve and update session data periodically without triggering an HTTP request. +- **Background Processing**: Manage sessions for tasks running in the background, such as user inactivity checks or batch processing. +::: + ## Examples :::note From 00b9e07bb2e899f99978e1dc23406d6df04044d2 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 17:52:06 -0300 Subject: [PATCH 76/79] docs(middleware/session): markdown lint --- docs/middleware/session.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 808f9ee01e..ff73ff6094 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -229,22 +229,27 @@ func (s *Store) Delete(id string) error ``` :::note + #### `GetByID` Method The `GetByID` method retrieves a session from storage using its session ID. Unlike `Get`, which ties the session to a `fiber.Ctx` (request-response cycle), `GetByID` operates independently of any HTTP context. This makes it ideal for scenarios such as background processing, scheduled tasks, or non-HTTP-related session management. -##### Key Features: +##### Key Features + - **Context Independence**: Sessions retrieved via `GetByID` are not bound to `fiber.Ctx`. This means the session can be manipulated in contexts that aren't tied to an active HTTP request-response cycle. - **Background Task Suitability**: Use this method when you need to manage sessions outside of the standard HTTP workflow, such as in scheduled jobs, background tasks, or any non-HTTP context where session data needs to be accessed or modified. -##### Usage Considerations: +##### Usage Considerations + - **Manual Persistence**: Since there is no associated `fiber.Ctx`, changes made to the session (e.g., modifying data) will **not** automatically be saved to storage. You **must** call `session.Save()` explicitly to persist any updates to storage. - **No Automatic Cookie Handling**: Any updates made to the session will **not** affect the client-side cookies. If the session changes need to be reflected in the client (e.g., in a future HTTP response), you will need to handle this manually by setting the cookies via other methods. - **Resource Management**: After using a session retrieved by `GetByID`, you should call `session.Release()` to properly release the session back to the pool and free up resources. -##### Example Use Cases: +##### Example Use Cases + - **Scheduled Jobs**: Retrieve and update session data periodically without triggering an HTTP request. - **Background Processing**: Manage sessions for tasks running in the background, such as user inactivity checks or batch processing. + ::: ## Examples From 23e823b9b9436302b594ba783e9f5e1d670798bc Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 20:00:44 -0300 Subject: [PATCH 77/79] refactor(middleware/session): Simplify error handling in DefaultErrorHandler --- middleware/session/config.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index 21513d79cb..a1acd2656c 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -125,10 +125,8 @@ var ConfigDefault = Config{ // DefaultErrorHandler(c, err) func DefaultErrorHandler(c fiber.Ctx, err error) { log.Errorf("session: %v", err) - if c != nil { - if sendErr := (c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { - log.Errorf("session: %v", sendErr) - } + if sendErr := (c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { + log.Errorf("session: %v", sendErr) } } From ba387862d6d3f540f9bcc5abe5402b8b434649f0 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 2 Oct 2024 21:08:25 -0300 Subject: [PATCH 78/79] fix( middleware/session/config.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- middleware/session/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/session/config.go b/middleware/session/config.go index a1acd2656c..c2a115d732 100644 --- a/middleware/session/config.go +++ b/middleware/session/config.go @@ -125,7 +125,7 @@ var ConfigDefault = Config{ // DefaultErrorHandler(c, err) func DefaultErrorHandler(c fiber.Ctx, err error) { log.Errorf("session: %v", err) - if sendErr := (c).SendStatus(fiber.StatusInternalServerError); sendErr != nil { + if sendErr := c.SendStatus(fiber.StatusInternalServerError); sendErr != nil { log.Errorf("session: %v", sendErr) } } From b54c954f5485c93f3e7ded5b3eafce589741b4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Wed, 23 Oct 2024 12:05:51 +0200 Subject: [PATCH 79/79] add ctx releases for the test cases --- middleware/session/store_test.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 71a10dc941..8a45c7e5fb 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -23,6 +23,7 @@ func Test_Store_getSessionID(t *testing.T) { store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.sessionName, expectedID) @@ -38,6 +39,7 @@ func Test_Store_getSessionID(t *testing.T) { }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set header ctx.Request().Header.Set(store.sessionName, expectedID) @@ -53,6 +55,7 @@ func Test_Store_getSessionID(t *testing.T) { }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set url parameter ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.sessionName, expectedID)) @@ -76,6 +79,7 @@ func Test_Store_Get(t *testing.T) { store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.sessionName, unexpectedID) @@ -97,6 +101,7 @@ func Test_Store_DeleteSession(t *testing.T) { // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // Create a new session session, err := store.Get(ctx) @@ -123,6 +128,7 @@ func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { // Create a new context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) // Mock middleware and set it in the context middleware := &Middleware{} @@ -178,10 +184,12 @@ func Test_Store_GetByID(t *testing.T) { t.Run("valid session ID", func(t *testing.T) { t.Parallel() + app := fiber.New() // Create a new session - ctx := fiber.New().AcquireCtx(&fasthttp.RequestCtx{}) + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) session, err := store.Get(ctx) defer session.Release() + defer app.ReleaseCtx(ctx) require.NoError(t, err) // Save the session ID