Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now store the session as JSON. #4829

Merged
merged 4 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/migrations/006.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE persistent_session ADD COLUMN "session_json" "jsonb";
72 changes: 54 additions & 18 deletions server/src/Utopia/Web/Auth/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
-}
module Utopia.Web.Auth.Session where

import Control.Lens
import Control.Monad
import qualified Data.HashMap.Strict as HM
import qualified Data.HashMap.Strict as M
import Data.List (lookup)
import Data.Profunctor.Product
import qualified Data.Serialize as S
Expand All @@ -35,6 +35,7 @@ import Web.Cookie
import Web.PathPieces (fromPathPiece, toPathPiece)
import Web.ServerSession.Core hiding (setCookieName)
import Web.ServerSession.Core.Internal (storage, unS)
import Data.Aeson

newtype ConnectionStorage sess = ConnectionStorage { pool :: DBPool }
deriving (Typeable)
Expand All @@ -43,10 +44,10 @@ type SessionStorage = ConnectionStorage SessionMap

type SessionState = State SessionStorage

type PersistentSessionFields = (Field SqlText, FieldNullable SqlBytea, Field SqlBytea, Field SqlTimestamptz, Field SqlTimestamptz)
type PersistentSessionFields = (Field SqlText, FieldNullable SqlBytea, Field SqlBytea, FieldNullable SqlJsonb, Field SqlTimestamptz, Field SqlTimestamptz)

persistentSessionTable :: Table PersistentSessionFields PersistentSessionFields
persistentSessionTable = table "persistent_session" (p5 (tableField "key", tableField "auth_id", tableField "session", tableField "created_at", tableField "accessed_at"))
persistentSessionTable = table "persistent_session" (p6 (tableField "key", tableField "auth_id", tableField "session", tableField "session_json", tableField "created_at", tableField "accessed_at"))

instance S.Serialize SessionMap where
put = S.put . map (first TE.encodeUtf8) . HM.toList . unSessionMap
Expand All @@ -55,8 +56,8 @@ instance S.Serialize SessionMap where
throwSS :: StorageException SessionStorage -> TransactionM SessionStorage a
throwSS = liftIO . throwIO

sessionFromTable :: (Text, Maybe ByteString, ByteString, UTCTime, UTCTime) -> IO (Session (SessionData SessionStorage))
sessionFromTable (key, authId, session, createdAt, accessedAt) = do
sessionFromTable :: (Text, Maybe ByteString, ByteString, Maybe Value, UTCTime, UTCTime) -> IO (Session (SessionData SessionStorage))
sessionFromTable (key, authId, session, _, createdAt, accessedAt) = do
parsedKey <- maybe (fail "Could not parse key.") pure $ fromPathPiece key
sessionMap <- either fail pure $ S.decode session :: IO SessionMap
pure $ Session
Expand All @@ -67,26 +68,41 @@ sessionFromTable (key, authId, session, createdAt, accessedAt) = do
, sessionAccessedAt = accessedAt
}

sessionToTable :: Session (SessionData SessionStorage) -> PersistentSessionFields
sessionToTable Session{..} =
data UserSession = UserSession { userID :: Text }
deriving (Eq, Ord, Show, Generic)

instance ToJSON UserSession where

sessionMapToJSON :: SessionMap -> IO Value
sessionMapToJSON sessionMap = do
let actualMap = unSessionMap sessionMap
let sessionKeys = HM.keys actualMap
unless (sessionKeys == ["user_id"]) $ do
fail ("Unexpected keys in session map: " <> show sessionKeys)
userID <- maybe (fail "Could not find user_id.") (\bytes -> pure $ TE.decodeUtf8 bytes) $ HM.lookup "user_id" actualMap
pure $ toJSON $ UserSession userID

-- Now ensure that when creating/updating the session table contents, ensure that the session_json column gets populated.
sessionToTable :: Value -> Session (SessionData SessionStorage) -> PersistentSessionFields
sessionToTable sessionDataJSON Session{..} =
let key = toFields $ unS sessionKey
authId = toFields sessionAuthId
session = toFields $ S.encode sessionData
createdAt = toFields sessionCreatedAt
accessedAt = toFields sessionAccessedAt
in (key, authId, session, createdAt, accessedAt)
in (key, authId, session, toFields $ Just sessionDataJSON, createdAt, accessedAt)

persistentSessionTableSelect :: Select PersistentSessionFields
persistentSessionTableSelect = selectTable persistentSessionTable

persistentSessionTableSelectByKey :: Text -> Select PersistentSessionFields
persistentSessionTableSelectByKey key = do
row@(rowKey, _, _, _, _) <- persistentSessionTableSelect
row@(rowKey, _, _, _, _, _) <- persistentSessionTableSelect
where_ (rowKey .== toFields key)
pure row

persistentSessionTableKeyPredicate :: Text -> PersistentSessionFields -> Column SqlBool
persistentSessionTableKeyPredicate key (rowKey, _, _, _, _) = rowKey .== toFields key
persistentSessionTableKeyPredicate key (rowKey, _, _, _, _, _) = rowKey .== toFields key

persistentSessionTableDeleteByKey :: Text -> ReaderT Connection IO ()
persistentSessionTableDeleteByKey key = do
Expand All @@ -102,14 +118,27 @@ persistentSessionTableDeleteByAuthId authId = do
connection <- ask
void $ liftIO $ runDelete connection $ Delete
{ dTable = persistentSessionTable
, dWhere = (\(_, rowAuthId, _, _, _) -> (nullableToMaybeFields rowAuthId) .=== (nullableToMaybeFields (toFields (Just authId))))
, dWhere = (\(_, rowAuthId, _, _, _, _) -> (nullableToMaybeFields rowAuthId) .=== (nullableToMaybeFields (toFields (Just authId))))
, dReturning = rCount
}

lookupSession :: SessionStorage -> SessionId (SessionData SessionStorage) -> ReaderT Connection IO (Maybe (Session (SessionData SessionStorage)))
lookupSession _ sessionId = do
connection <- ask
possibleSession <- liftIO $ fmap listToMaybe $ runSelect connection (persistentSessionTableSelectByKey $ toPathPiece sessionId)
possibleSession <- liftIO $ fmap listToMaybe $ runSelect connection (persistentSessionTableSelectByKey $ toPathPiece sessionId) :: ReaderT Connection IO (Maybe ((Text, Maybe ByteString, ByteString, Maybe Value, UTCTime, UTCTime)))
case possibleSession of
-- If the session_json column is null, then populate it.
Just session -> liftIO $ when (isNothing $ view _4 session) $ do
let sessionBytes = view _3 session
sessionMap <- either fail pure $ S.decode sessionBytes :: IO SessionMap
sessionDataJSON <- liftIO $ sessionMapToJSON sessionMap
void $ runUpdate connection $ Update
{ uTable = persistentSessionTable
, uUpdateWith = updateEasy (set _4 $ toFields $ Just sessionDataJSON)
, uWhere = (\(rowKey, _, _, _, _, _) -> rowKey .== toFields (view _1 session))
, uReturning = rCount
}
Nothing -> pure ()
liftIO $ traverse sessionFromTable possibleSession

instance Storage SessionStorage where
Expand All @@ -127,19 +156,23 @@ instance Storage SessionStorage where
existingSession <- lookupSession sto (sessionKey session)
case existingSession of
Just old -> throwSS $ SessionAlreadyExists old session
Nothing -> liftIO $ void $ runInsert connection $ Insert
Nothing -> liftIO $ do
sessionDataJSON <- liftIO $ sessionMapToJSON $ sessionData session
void $ runInsert connection $ Insert
{ iTable = persistentSessionTable
, iRows = [sessionToTable session]
, iRows = [sessionToTable sessionDataJSON session]
, iReturning = rCount
, iOnConflict = Nothing
}
replaceSession sto session = do
existingSession <- lookupSession sto (sessionKey session)
connection <- ask
case existingSession of
Just _ -> liftIO $ void $ runUpdate connection $ Update
Just _ -> liftIO $ do
sessionDataJSON <- liftIO $ sessionMapToJSON $ sessionData session
void $ runUpdate connection $ Update
{ uTable = persistentSessionTable
, uUpdateWith = updateEasy (const $ sessionToTable session)
, uUpdateWith = updateEasy (const $ sessionToTable sessionDataJSON session)
, uWhere = persistentSessionTableKeyPredicate $ toPathPiece $ sessionKey session
, uReturning = rCount
}
Expand Down Expand Up @@ -171,10 +204,13 @@ createCookie sessionState session = def
, setCookieValue = encodeUtf8 $ toPathPiece $ sessionKey session
}




newSessionForUser :: SessionState -> Text -> IO (Maybe SetCookie)
newSessionForUser sessionState userId = do
(sessionData, saveSessionToken) <- loadSession sessionState Nothing
savedResult <- saveSession sessionState saveSessionToken $ SessionMap $ M.insert "user_id" (encodeUtf8 userId) $ unSessionMap sessionData
savedResult <- saveSession sessionState saveSessionToken $ SessionMap $ HM.insert "user_id" (encodeUtf8 userId) $ unSessionMap sessionData
return $ fmap (createCookie sessionState) savedResult

getSessionIdFromCookie :: SessionState -> Maybe Text -> Maybe Text
Expand All @@ -187,7 +223,7 @@ getUserIdFromCookie :: SessionState -> Maybe Text -> IO (Maybe Text)
getUserIdFromCookie sessionState cookieContents = do
let possibleSessionId = getSessionIdFromCookie sessionState cookieContents
(sessionData, _) <- loadSession sessionState $ fmap encodeUtf8 possibleSessionId
case M.lookup "user_id" $ unSessionMap sessionData of
case HM.lookup "user_id" $ unSessionMap sessionData of
Just userId -> either (fail . show) (pure . Just) $ decodeUtf8' userId
Nothing -> pure Nothing

Expand Down
1 change: 1 addition & 0 deletions server/src/Utopia/Web/Database/Migrations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ migrateDatabase verbose includeInitial pool = withResource pool $ \connection ->
, MigrationFile "003.sql" "./migrations/003.sql"
, MigrationFile "004.sql" "./migrations/004.sql"
, MigrationFile "005.sql" "./migrations/005.sql"
, MigrationFile "006.sql" "./migrations/006.sql"
]
let initialMigrationCommand = if includeInitial
then [MigrationFile "initial.sql" "./migrations/initial.sql"]
Expand Down
Loading