Skip to content

Commit

Permalink
fix(middleware/session): mutex for thread safety (#3050)
Browse files Browse the repository at this point in the history
* chore: Remove extra release and acquire ctx calls in session_test.go

* feat: Remove unnecessary session mutex lock in decodeSessionData function

* chore: Refactor session benchmark tests

* fix(middleware/session): mutex for thread safety

* feat: Add session mutex lock for thread safety

* chore: Refactor releaseSession mutex
  • Loading branch information
sixcolors authored Jun 30, 2024
1 parent 6fa0e7c commit 66a8814
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 24 deletions.
49 changes: 41 additions & 8 deletions middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
)

type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx *fiber.Ctx // fiber context
Expand Down Expand Up @@ -42,6 +43,7 @@ func acquireSession() *Session {
}

func releaseSession(s *Session) {
s.mu.Lock()
s.id = ""
s.exp = 0
s.ctx = nil
Expand All @@ -52,16 +54,21 @@ func releaseSession(s *Session) {
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
s.mu.Unlock()
sessionPool.Put(s)
}

// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}

// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}

Expand Down Expand Up @@ -102,6 +109,9 @@ func (s *Session) Destroy() error {
// Reset local data
s.data.Reset()

s.mu.RLock()
defer s.mu.RUnlock()

// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -114,6 +124,9 @@ func (s *Session) Destroy() error {

// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()

// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -131,6 +144,10 @@ func (s *Session) Reset() error {
if s.data != nil {
s.data.Reset()
}

s.mu.Lock()
defer s.mu.Unlock()

// Reset byte buffer
if s.byteBuffer != nil {
s.byteBuffer.Reset()
Expand All @@ -154,20 +171,24 @@ func (s *Session) Reset() error {

// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
// Create a new id
s.id = s.config.KeyGenerator()

// We assign a new id to the session, so the session must be fresh
s.fresh = true
}

// Save will update the storage and client cookie
//
// sess.Save() will save the session data to the storage and update the
// client cookie, and it will release the session after saving.
//
// It's not safe to use the session after calling Save().
func (s *Session) Save() error {
// Better safe than sorry
if s.data == nil {
return nil
}

s.mu.Lock()

// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
Expand All @@ -177,25 +198,25 @@ func (s *Session) Save() error {
s.setSession()

// Convert data to bytes
mux.Lock()
defer mux.Unlock()
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}

// copy the data in buffer
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())

// pass copied bytes with session id to provider
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}

s.mu.Unlock()

// Release session
// TODO: It's not safe to use the Session after called Save()
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)

return nil
Expand All @@ -211,6 +232,8 @@ func (s *Session) Keys() []string {

// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}

Expand Down Expand Up @@ -276,3 +299,13 @@ func (s *Session) delSession() {
fasthttp.ReleaseCookie(fcookie)
}
}

// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}
229 changes: 229 additions & 0 deletions middleware/session/session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package session

import (
"errors"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) {
utils.AssertEqual(b, nil, err)
})
}

// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
func Benchmark_Session_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")

sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})

b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")

sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})
}

// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
func Benchmark_Session_Asserted(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})

b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})
}

// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")

sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})

b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")

sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})
}

// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()

var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test

// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()

localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})

sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Set a value
sess.Set("name", "john")

// get the session id
id := sess.ID()

// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}

// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}

// Release the context
app.ReleaseCtx(localCtx)

// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)

// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)

// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}

// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}

// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}

// Delete the key
sess.Delete("name")

// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}

// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}

wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent

// Check for errors sent to errChan
for err := range errChan {
utils.AssertEqual(t, nil, err)
}
}
Loading

0 comments on commit 66a8814

Please sign in to comment.