This repository has been archived by the owner on Dec 14, 2023. It is now read-only.
forked from boj/redistore
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from getbread/cdoxsey/grow20
add new cookieless session store
- Loading branch information
Showing
5 changed files
with
138 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package redistore | ||
|
||
import ( | ||
"crypto/hmac" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
// errors | ||
var ( | ||
ErrHeaderNotFound = errors.New("header not found") | ||
ErrInvalidSessionID = errors.New("invalid session id") | ||
) | ||
|
||
// A CookielessSessionIDStore stores session ids in http headers | ||
type CookielessSessionIDStore struct { | ||
name string | ||
password string | ||
} | ||
|
||
// NewCookielessSessionIDStore creates a new CookielessSessionIDStore | ||
func NewCookielessSessionIDStore(name, password string) *CookielessSessionIDStore { | ||
return &CookielessSessionIDStore{ | ||
name: name, | ||
password: password, | ||
} | ||
} | ||
|
||
// Load attempts to load a session id from an http request | ||
func (store CookielessSessionIDStore) Load(r *http.Request) (string, error) { | ||
value := r.Header.Get(store.headerName()) | ||
if value == "" { | ||
return "", ErrHeaderNotFound | ||
} | ||
|
||
idx := strings.IndexByte(value, ':') | ||
if idx < 0 { | ||
return "", ErrInvalidSessionID | ||
} | ||
|
||
mac, sessionID := value[:idx], value[idx+1:] | ||
if !store.verify(mac, sessionID) { | ||
return "", ErrInvalidSessionID | ||
} | ||
|
||
return sessionID, nil | ||
} | ||
|
||
// Save attemps to save a session id to an http response writer | ||
func (store CookielessSessionIDStore) Save(sessionID string, w http.ResponseWriter) { | ||
mac := store.sign(sessionID) | ||
w.Header().Set(store.headerName(), fmt.Sprintf("%s:%s", mac, sessionID)) | ||
} | ||
|
||
func (store CookielessSessionIDStore) sign(message string) (mac string) { | ||
h := hmac.New(sha256.New, []byte(store.password)) | ||
io.WriteString(h, message) | ||
signature := h.Sum(nil) | ||
return hex.EncodeToString(signature) | ||
} | ||
|
||
func (store CookielessSessionIDStore) verify(mac, message string) bool { | ||
mac1, err := hex.DecodeString(mac) | ||
if err != nil { | ||
return false | ||
} | ||
|
||
mac2, err := hex.DecodeString(store.sign(message)) | ||
if err != nil { | ||
return false | ||
} | ||
|
||
return hmac.Equal(mac1, mac2) | ||
} | ||
|
||
func (store CookielessSessionIDStore) headerName() string { | ||
return fmt.Sprintf("X-SESSION-ID-%s", store.name) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package redistore | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestCookielessSessionIDStore(t *testing.T) { | ||
store := NewCookielessSessionIDStore("TEST", "ABCD") | ||
|
||
t.Run("load empty", func(t *testing.T) { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
sessionID, err := store.Load(req) | ||
assert.Empty(t, sessionID) | ||
assert.Equal(t, ErrHeaderNotFound, err) | ||
}) | ||
t.Run("load invalid", func(t *testing.T) { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
req.Header.Set("X-SESSION-ID-TEST", "XYZ:value") | ||
sessionID, err := store.Load(req) | ||
assert.Empty(t, sessionID) | ||
assert.Equal(t, ErrInvalidSessionID, err) | ||
}) | ||
t.Run("load valid", func(t *testing.T) { | ||
req, _ := http.NewRequest("GET", "/", nil) | ||
req.Header.Set("X-SESSION-ID-TEST", "94d5574a0ef464c629296fc9d263517944b94d1df9f3472fb7fb2d90af42ca36:value") | ||
sessionID, err := store.Load(req) | ||
assert.NotEmpty(t, sessionID) | ||
assert.NoError(t, err) | ||
}) | ||
|
||
t.Run("save", func(t *testing.T) { | ||
rec := httptest.NewRecorder() | ||
store.Save("value", rec) | ||
assert.Equal(t, "94d5574a0ef464c629296fc9d263517944b94d1df9f3472fb7fb2d90af42ca36:value", rec.Header().Get("X-SESSION-ID-TEST")) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters