diff --git a/CHANGELOG.md b/CHANGELOG.md index 90706a19e..9696a64b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +# 6.0.0 + +Version 6.0.0.8 + +Agent: +- enabled fast handshake support. +- batch-send multiple messages in each connection. +- resume subscriptions as soon as agent moves to foreground or as network connection resumes. +- "known" servers to determine whether to use SMP proxy. +- retry on SMP proxy NO_SESSION error. +- fixes to notification subscriptions. +- persistent server statistics. +- better concurrency. + +SMP server: +- reduce threads usage. +- additional statistics. +- improve disabling inactive clients. +- additional control port commands for monitoring. + +Notification server: +- support onion-only SMP servers. + # 5.8.2 Agent: diff --git a/package.yaml b/package.yaml index c58586aa6..fb6e19db7 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 6.0.0.2 +version: 6.0.0.8 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index a32faa8fd..2fd32b1d7 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 6.0.0.2 +version: 6.0.0.8 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 725444086..d6ee75ae9 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -43,7 +43,7 @@ import Data.Either (partitionEithers, rights) import Data.Int (Int64) import Data.List (foldl', partition, sortOn) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (mapMaybe) import qualified Data.Set as S @@ -184,7 +184,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -194,6 +194,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do (fc@RcvFileChunk {userId, rcvFileId, rcvFileEntityId, digest, fileTmpPath, replicas = replica@RcvFileChunkReplica {rcvChunkReplicaId, server, delay} : _}, approvedRelays) -> do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv downloadAttempts downloadFileChunk fc replica approvedRelays @@ -204,7 +205,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c rcvFileEntityId $ RFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically . incXFTPServerStat c userId srv $ case e of @@ -220,7 +221,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do chunkSpec = XFTPRcvChunkSpec chunkPath chSize (unFileDigest digest) relChunkPath = fileTmpPath takeFileName chunkPath agentXFTPDownloadChunk c userId digest replica chunkSpec - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c (entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId @@ -239,7 +240,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do where ipAddressProtected' :: AM Bool ipAddressProtected' = do - cfg <- liftIO $ getNetworkConfig' c + cfg <- liftIO $ getFastNetworkConfig c pure $ ipAddressProtected cfg srv receivedSize :: [RcvFileChunk] -> Int64 receivedSize = foldl' (\sz ch -> sz + receivedChunkSize ch) 0 @@ -272,7 +273,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -298,12 +299,12 @@ runXFTPRcvLocalWorker c Worker {doWork} = do Nothing -> do notify c rcvFileEntityId $ RFDONE fsSavePath lift $ forM_ tmpPath (removePath <=< toFSFilePath) - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) Just RcvFileRedirect {redirectFileInfo, redirectDbId} -> do let RedirectFileInfo {size = redirectSize, digest = redirectDigest} = redirectFileInfo lift $ forM_ tmpPath (removePath <=< toFSFilePath) - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) -- proceed with redirect yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) @@ -391,7 +392,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -453,7 +454,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do SndFileChunkReplica {server} : _ -> Right server createChunk :: Int -> SndFileChunk -> AM (ProtocolServer 'PXFTP) createChunk numRecipients' ch = do - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c (replica, ProtoServerWithAuth srv _) <- tryCreate withStore' c $ \db -> createSndFileReplica db ch replica pure srv @@ -461,8 +462,9 @@ runXFTPSndPrepareWorker c Worker {doWork} = do tryCreate = do usedSrvs <- newTVarIO ([] :: [XFTPServer]) let AgentClient {xftpServers} = c - userSrvCount <- length <$> atomically (TM.lookup userId xftpServers) + userSrvCount <- liftIO $ length <$> TM.lookupIO userId xftpServers withRetryIntervalCount (riFast ri) $ \n _ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c let triedAllSrvs = n > userSrvCount createWithNextSrv usedSrvs @@ -472,7 +474,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do retryLoop loop triedAllSrvs e = do flip catchAgentError (\_ -> pure ()) $ do when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop createWithNextSrv usedSrvs = do deleted <- withStore' c $ \db -> getSndFileDeleted db sndFileId @@ -492,7 +494,7 @@ runXFTPSndWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -502,6 +504,7 @@ runXFTPSndWorker c srv Worker {doWork} = do fc@SndFileChunk {userId, sndFileId, sndFileEntityId, filePrefixPath, digest, replicas = replica@SndFileChunkReplica {sndChunkReplicaId, server, delay} : _} -> do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv uploadAttempts uploadFileChunk cfg fc replica @@ -512,7 +515,7 @@ runXFTPSndWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically $ incXFTPServerStat c userId srv uploadErrs @@ -523,9 +526,9 @@ runXFTPSndWorker c srv Worker {doWork} = do fsFilePath <- lift $ toFSFilePath filePath unlessM (doesFileExist fsFilePath) $ throwE $ FILE NO_FILE let chunkSpec' = chunkSpec {filePath = fsFilePath} :: XFTPChunkSpec - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec' - atomically $ waitUntilForeground c + liftIO $ waitUntilForeground c sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded getSndFile db sndFileId @@ -663,7 +666,7 @@ runXFTPDelWorker c srv Worker {doWork} = do cfg <- asks config forever $ do lift $ waitForWork doWork - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c runXFTPOperation cfg where runXFTPOperation :: AgentConfig -> AM () @@ -674,6 +677,7 @@ runXFTPDelWorker c srv Worker {doWork} = do processDeletedReplica replica@DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, chunkDigest, delay} = do let ri' = maybe ri (\d -> ri {initialInterval = d, increaseAfter = 0}) delay withRetryIntervalLimit xftpConsecutiveRetries ri' $ \delay' loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv deleteAttempts deleteChunkReplica @@ -684,7 +688,7 @@ runXFTPDelWorker c srv Worker {doWork} = do when (serverHostError e) $ notify c "" $ SFWARN e liftIO $ closeXFTPServerClient c userId server chunkDigest withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay - atomically $ assertAgentForeground c + liftIO $ assertAgentForeground c loop retryDone e = do atomically $ incXFTPServerStat c userId srv deleteErrs @@ -699,7 +703,7 @@ delWorkerInternalError c deletedSndChunkReplicaId e = do withStore' c $ \db -> deleteDeletedSndChunkReplica db deletedSndChunkReplicaId notify c "" $ SFERR e -assertAgentForeground :: AgentClient -> STM () +assertAgentForeground :: AgentClient -> IO () assertAgentForeground c = do throwWhenInactive c waitUntilForeground c diff --git a/src/Simplex/FileTransfer/Client/Agent.hs b/src/Simplex/FileTransfer/Client/Agent.hs index 86b093ee7..863a91ce1 100644 --- a/src/Simplex/FileTransfer/Client/Agent.hs +++ b/src/Simplex/FileTransfer/Client/Agent.hs @@ -53,9 +53,9 @@ defaultXFTPClientAgentConfig = data XFTPClientAgentError = XFTPClientAgentError XFTPServer XFTPClientError deriving (Show, Exception) -newXFTPAgent :: XFTPClientAgentConfig -> STM XFTPClientAgent +newXFTPAgent :: XFTPClientAgentConfig -> IO XFTPClientAgent newXFTPAgent config = do - xftpClients <- TM.empty + xftpClients <- TM.emptyIO pure XFTPClientAgent {xftpClients, config} type ME a = ExceptT XFTPClientAgentError IO a diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index aeac956e6..1eea6ef5a 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -43,7 +43,7 @@ import Data.Int (Int64) import Data.List (foldl', sortOn) import Data.List.NonEmpty (NonEmpty (..), nonEmpty) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map as M import Data.Maybe (fromMaybe, listToMaybe) import qualified Data.Text as T @@ -313,7 +313,7 @@ cliSendFileOpts SendOptions {filePath, outputDir, numRecipients, xftpServers, re pure (encPath, fdRcv, fdSnd, chunkSpecs, encSize) uploadFile :: TVar ChaChaDRG -> [XFTPChunkSpec] -> TVar [Int64] -> Int64 -> ExceptT CLIError IO [SentFileChunk] uploadFile g chunks uploadedChunks encSize = do - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig gen <- newTVarIO =<< liftIO newStdGen let xftpSrvs = fromMaybe defaultXFTPServers (nonEmpty xftpServers) srvs <- liftIO $ replicateM (length chunks) $ getXFTPServer gen xftpSrvs @@ -429,7 +429,7 @@ cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, receive (ValidFileDescription FileDescription {size, digest, key, nonce, chunks}) = do encPath <- getEncPath tempPath "xftp" createDirectory encPath - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig liftIO $ printNoNewLine "Downloading file..." downloadedChunks <- newTVarIO [] let srv FileChunk {replicas} = case replicas of @@ -494,7 +494,7 @@ cliDeleteFile DeleteOptions {fileDescription, retryCount, yes} = do where deleteFile :: ValidFileDescription 'FSender -> ExceptT CLIError IO () deleteFile (ValidFileDescription FileDescription {chunks}) = do - a <- atomically $ newXFTPAgent defaultXFTPClientAgentConfig + a <- liftIO $ newXFTPAgent defaultXFTPClientAgentConfig forM_ chunks $ deleteFileChunk a liftIO $ do printNoNewLine "File deleted!" diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index d5b5e5105..c702a177f 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -52,7 +52,7 @@ import Data.Int (Int64) import Data.List (foldl', sortOn) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import qualified Data.Map as M import Data.Maybe (fromMaybe) import Data.String diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 24dcc5e38..819be9a81 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -112,7 +112,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig, inactiveClientExpira Right pk' -> pure pk' Left e -> putStrLn ("servers has no valid key: " <> show e) >> exitFailure env <- ask - sessions <- atomically TM.empty + sessions <- liftIO TM.emptyIO let cleanup sessionId = atomically $ TM.delete sessionId sessions liftIO . runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig inactiveClientExpiration cleanup $ \sessionId sessionALPN r sendResponse -> do reqBody <- getHTTP2Body r xftpBlockSize @@ -576,7 +576,7 @@ incFileStat statSel = do saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getFileServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getFileServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/FileTransfer/Server/Env.hs b/src/Simplex/FileTransfer/Server/Env.hs index f8a6bc996..1fa399a2a 100644 --- a/src/Simplex/FileTransfer/Server/Env.hs +++ b/src/Simplex/FileTransfer/Server/Env.hs @@ -11,7 +11,6 @@ module Simplex.FileTransfer.Server.Env where import Control.Logger.Simple import Control.Monad -import Control.Monad.IO.Unlift import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -105,17 +104,17 @@ supportedXFTPhandshakes = ["xftp/1"] newXFTPServerEnv :: XFTPServerConfig -> IO XFTPEnv newXFTPServerEnv config@XFTPServerConfig {storeLogFile, fileSizeQuota, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do - random <- liftIO C.newRandom - store <- atomically newFileStore - storeLog <- liftIO $ mapM (`readWriteFileStore` store) storeLogFile + random <- C.newRandom + store <- newFileStore + storeLog <- mapM (`readWriteFileStore` store) storeLogFile used <- countUsedStorage <$> readTVarIO (files store) atomically $ writeTVar (usedStorage store) used forM_ fileSizeQuota $ \quota -> do logInfo $ "Total / available storage: " <> tshow quota <> " / " <> tshow (quota - used) when (quota < used) $ logInfo "WARNING: storage quota is less than used storage, no files can be uploaded!" - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) - Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile - serverStats <- atomically . newFileServerStats =<< liftIO getCurrentTime + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) + Fingerprint fp <- loadFingerprint caCertificateFile + serverStats <- newFileServerStats =<< getCurrentTime pure XFTPEnv {config, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats} countUsedStorage :: M.Map k FileRec -> Int64 diff --git a/src/Simplex/FileTransfer/Server/Stats.hs b/src/Simplex/FileTransfer/Server/Stats.hs index 08813dc2a..1178dd5f6 100644 --- a/src/Simplex/FileTransfer/Server/Stats.hs +++ b/src/Simplex/FileTransfer/Server/Stats.hs @@ -43,34 +43,34 @@ data FileServerStatsData = FileServerStatsData } deriving (Show) -newFileServerStats :: UTCTime -> STM FileServerStats +newFileServerStats :: UTCTime -> IO FileServerStats newFileServerStats ts = do - fromTime <- newTVar ts - filesCreated <- newTVar 0 - fileRecipients <- newTVar 0 - filesUploaded <- newTVar 0 - filesExpired <- newTVar 0 - filesDeleted <- newTVar 0 + fromTime <- newTVarIO ts + filesCreated <- newTVarIO 0 + fileRecipients <- newTVarIO 0 + filesUploaded <- newTVarIO 0 + filesExpired <- newTVarIO 0 + filesDeleted <- newTVarIO 0 filesDownloaded <- newPeriodStats - fileDownloads <- newTVar 0 - fileDownloadAcks <- newTVar 0 - filesCount <- newTVar 0 - filesSize <- newTVar 0 + fileDownloads <- newTVarIO 0 + fileDownloadAcks <- newTVarIO 0 + filesCount <- newTVarIO 0 + filesSize <- newTVarIO 0 pure FileServerStats {fromTime, filesCreated, fileRecipients, filesUploaded, filesExpired, filesDeleted, filesDownloaded, fileDownloads, fileDownloadAcks, filesCount, filesSize} -getFileServerStatsData :: FileServerStats -> STM FileServerStatsData +getFileServerStatsData :: FileServerStats -> IO FileServerStatsData getFileServerStatsData s = do - _fromTime <- readTVar $ fromTime (s :: FileServerStats) - _filesCreated <- readTVar $ filesCreated s - _fileRecipients <- readTVar $ fileRecipients s - _filesUploaded <- readTVar $ filesUploaded s - _filesExpired <- readTVar $ filesExpired s - _filesDeleted <- readTVar $ filesDeleted s + _fromTime <- readTVarIO $ fromTime (s :: FileServerStats) + _filesCreated <- readTVarIO $ filesCreated s + _fileRecipients <- readTVarIO $ fileRecipients s + _filesUploaded <- readTVarIO $ filesUploaded s + _filesExpired <- readTVarIO $ filesExpired s + _filesDeleted <- readTVarIO $ filesDeleted s _filesDownloaded <- getPeriodStatsData $ filesDownloaded s - _fileDownloads <- readTVar $ fileDownloads s - _fileDownloadAcks <- readTVar $ fileDownloadAcks s - _filesCount <- readTVar $ filesCount s - _filesSize <- readTVar $ filesSize s + _fileDownloads <- readTVarIO $ fileDownloads s + _fileDownloadAcks <- readTVarIO $ fileDownloadAcks s + _filesCount <- readTVarIO $ filesCount s + _filesSize <- readTVarIO $ filesSize s pure FileServerStatsData {_fromTime, _filesCreated, _fileRecipients, _filesUploaded, _filesExpired, _filesDeleted, _filesDownloaded, _fileDownloads, _fileDownloadAcks, _filesCount, _filesSize} setFileServerStats :: FileServerStats -> FileServerStatsData -> STM () diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index aa8eaa932..b56b516aa 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -55,11 +55,11 @@ instance StrEncoding FileRecipient where strEncode (FileRecipient rId rKey) = strEncode rId <> ":" <> strEncode rKey strP = FileRecipient <$> strP <* A.char ':' <*> strP -newFileStore :: STM FileStore +newFileStore :: IO FileStore newFileStore = do - files <- TM.empty - recipients <- TM.empty - usedStorage <- newTVar 0 + files <- TM.emptyIO + recipients <- TM.emptyIO + usedStorage <- newTVarIO 0 pure FileStore {files, recipients, usedStorage} addFile :: FileStore -> SenderId -> FileInfo -> SystemTime -> STM (Either XFTPErrorType ()) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 94ed7f44f..17d11246c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -33,6 +33,7 @@ module Simplex.Messaging.Agent AgentClient (..), AE, SubscriptionsInfo (..), + MsgReq, getSMPAgentClient, getSMPAgentClient_, disconnectAgentClient, @@ -106,6 +107,7 @@ module Simplex.Messaging.Agent rcConnectHost, rcConnectCtrl, rcDiscoverCtrl, + getAgentSubsTotal, getAgentServersSummary, resetAgentServersStats, foregroundAgent, @@ -129,6 +131,7 @@ import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) +import Data.Containers.ListUtils (nubOrd) import Data.Either (isRight, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) @@ -205,7 +208,7 @@ getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp} stor runAgent = do liftIO $ checkServers "SMP" smp >> checkServers "XFTP" xftp currentTs <- liftIO getCurrentTime - c@AgentClient {acThread} <- atomically . newAgentClient clientId initServers currentTs =<< ask + c@AgentClient {acThread} <- liftIO . newAgentClient clientId initServers currentTs =<< ask t <- runAgentThreads c `forkFinally` const (liftIO $ disconnectAgentClient c) atomically . writeTVar acThread . Just =<< mkWeakThreadId t pure c @@ -233,29 +236,30 @@ logServersStats c = do liftIO $ threadDelay' delay int <- asks (logStatsInterval . config) forever $ do + liftIO $ waitUntilActive c saveServersStats c liftIO $ threadDelay' int saveServersStats :: AgentClient -> AM' () -saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats} = do - sss <- mapM (lift . getAgentSMPServerStats) =<< readTVarIO smpServersStats - xss <- mapM (lift . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats - let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss} +saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServersStats} = do + sss <- mapM (liftIO . getAgentSMPServerStats) =<< readTVarIO smpServersStats + xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats + nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats + let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss} tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right () -> pure () restoreServersStats :: AgentClient -> AM' () -restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, srvStatsStartedAt} = do +restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do tryAgentError' (withStore c getServersStats) >>= \case Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt - Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss}) -> do + Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do atomically $ writeTVar srvStatsStartedAt startedAt - sss' <- mapM (atomically . newAgentSMPServerStats') sss - atomically $ writeTVar smpServersStats sss' - xss' <- mapM (atomically . newAgentXFTPServerStats') xss - atomically $ writeTVar xftpServersStats xss' + atomically . writeTVar smpServersStats =<< mapM (atomically . newAgentSMPServerStats') sss + atomically . writeTVar xftpServersStats =<< mapM (atomically . newAgentXFTPServerStats') xss + atomically . writeTVar ntfServersStats =<< mapM (atomically . newAgentNtfServerStats') nss disconnectAgentClient :: AgentClient -> IO () disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAgent = xa}} = do @@ -339,7 +343,7 @@ prepareConnectionToJoin :: AgentClient -> UserId -> Bool -> ConnectionRequestUri prepareConnectionToJoin c userId enableNtfs = withAgentEnv c .: newConnToJoin c userId "" enableNtfs -- | Join SMP agent connection (JOIN command). -joinConnection :: AgentClient -> UserId -> Maybe ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId +joinConnection :: AgentClient -> UserId -> Maybe ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId Nothing enableNtfs = withAgentEnv c .:: joinConn c userId "" False enableNtfs joinConnection c userId (Just connId) enableNtfs = withAgentEnv c .:: joinConn c userId connId True enableNtfs {-# INLINE joinConnection #-} @@ -350,7 +354,7 @@ allowConnection c = withAgentEnv c .:. allowConnection' c {-# INLINE allowConnection #-} -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId +acceptContact :: AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AE (ConnId, SndQueueSecured) acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs {-# INLINE acceptContact #-} @@ -375,7 +379,7 @@ getConnectionMessage c = withAgentEnv c . getConnectionMessage' c {-# INLINE getConnectionMessage #-} -- | Get connection message for received notification -getNotificationMessage :: AgentClient -> C.CbNonce -> ByteString -> AE (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage :: AgentClient -> C.CbNonce -> ByteString -> AE (NotificationInfo, Maybe SMPMsgMeta) getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c {-# INLINE getNotificationMessage #-} @@ -392,6 +396,10 @@ sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> A sendMessage c = withAgentEnv c .:: sendMessage' c {-# INLINE sendMessage #-} +-- When sending multiple messages to the same connection, +-- only the first MsgReq for this connection should have non-empty ConnId. +-- All subsequent MsgReq in traversable for this connection must be empty. +-- This is done to optimize processing by grouping all messages to one connection together. type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) @@ -783,7 +791,7 @@ newConnToJoin c userId connId enableNtfs cReq pqSup = case cReq of cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} withStore c $ \db -> createNewConn db g cData SCMInvitation -joinConn :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId +joinConn :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (ConnId, SndQueueSecured) joinConn c userId connId hasNewConn enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> @@ -842,7 +850,7 @@ versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ {-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ConnId +joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, SndQueueSecured) joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do (cData, q, _, rc, e2eSndParams) <- startJoinInvitation userId connId Nothing enableNtfs inv pqSup @@ -859,7 +867,7 @@ joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo p -- otherwise we would need to manage retries here to avoid SndQueue recreated with a different key, -- similar to how joinConnAsync does that. tryError (secureConfirmQueue c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case - Right _ -> pure connId' + Right sqSecured -> pure (connId', sqSecured) Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md void $ withStore' c $ \db -> deleteConn db Nothing connId' @@ -869,10 +877,10 @@ joinConnSrv c userId connId hasNewConn enableNtfs cReqUri@CRContactUri {} cInfo Just (qInfo, vrsn) -> do (connId', cReq) <- newConnSrv c userId connId hasNewConn enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv void $ sendInvitation c userId qInfo vrsn cReq cInfo - pure connId' + pure (connId', False) Nothing -> throwE $ AGENT A_VERSION -joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM () +joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM SndQueueSecured joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do SomeConn cType conn <- withStore c (`getConn` connId) case conn of @@ -880,7 +888,7 @@ joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSuppo SndConnection _ sq -> doJoin $ Just sq _ -> throwE $ CMD PROHIBITED $ "joinConnSrvAsync: bad connection " <> show cType where - doJoin :: Maybe SndQueue -> AM () + doJoin :: Maybe SndQueue -> AM SndQueueSecured doJoin sq_ = do (cData, sq, _, rc, e2eSndParams) <- startJoinInvitation userId connId sq_ enableNtfs inv pqSupport sq' <- withStore c $ \db -> runExceptT $ do @@ -907,18 +915,14 @@ createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVers allowConnection' :: AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> AM () allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConnection" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (RcvConnection _ rq@RcvQueue {server, rcvId, e2ePrivKey, smpClientVersion = v}) -> do - senderKey <- withStore c $ \db -> runExceptT $ do - AcceptedConfirmation {ratchetState, senderConf = SMPConfirmation {senderKey, e2ePubKey, smpClientVersion = v'}} <- ExceptT $ acceptConfirmation db confId ownConnInfo - liftIO $ createRatchet db connId ratchetState - let dhSecret = C.dh' e2ePubKey e2ePrivKey - liftIO $ setRcvQueueConfirmedE2E db rq dhSecret $ min v v' - pure senderKey + SomeConn _ (RcvConnection _ RcvQueue {server, rcvId}) -> do + AcceptedConfirmation {senderConf = SMPConfirmation {senderKey}} <- + withStore c $ \db -> acceptConfirmation db confId ownConnInfo enqueueCommand c "" connId (Just server) . AInternalCommand $ ICAllowSecure rcvId senderKey _ -> throwE $ CMD PROHIBITED "allowConnection" -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId +acceptContact' :: AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> AM (ConnId, SndQueueSecured) acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case @@ -956,7 +960,7 @@ subscribeConnections' c connIds = do errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs mapM_ (mapM_ (\(cData, sqs) -> mapM_ (lift . resumeMsgDelivery c cData) sqs) . sndQueue) cs - mapM_ (resumeConnCmds c) $ M.keys cs + lift $ resumeConnCmds c $ M.keys cs rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) @@ -1036,7 +1040,7 @@ getConnectionMessage' c connId = do SndConnection _ _ -> throwE $ CONN SIMPLEX NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection" -getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, [SMPMsgMeta]) +getNotificationMessage' :: AgentClient -> C.CbNonce -> ByteString -> AM (NotificationInfo, Maybe SMPMsgMeta) getNotificationMessage' c nonce encNtfInfo = do withStore' c getActiveNtfToken >>= \case Just NtfToken {ntfDhSecret = Just dhSecret} -> do @@ -1044,22 +1048,9 @@ getNotificationMessage' c nonce encNtfInfo = do PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} <- liftEither (parse strP (INTERNAL "error parsing PNMessageData") ntfData) (ntfConnId, rcvNtfDhSecret) <- withStore c (`getNtfRcvQueue` smpQueue) ntfMsgMeta <- (eitherToMaybe . smpDecode <$> agentCbDecrypt rcvNtfDhSecret nmsgNonce encNMsgMeta) `catchAgentError` \_ -> pure Nothing - maxMsgs <- asks $ ntfMaxMessages . config - (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta},) <$> getNtfMessages ntfConnId ntfMsgMeta maxMsgs + msgMeta <- getConnectionMessage' c ntfConnId + pure (NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta}, msgMeta) _ -> throwE $ CMD PROHIBITED "getNotificationMessage" - where - getNtfMessages ntfConnId nMeta = getMsg - where - getMsg 0 = pure [] - getMsg n = - getConnectionMessage' c ntfConnId >>= \case - Just m - | lastMsg m -> pure [m] - | otherwise -> (m :) <$> getMsg (n - 1) - Nothing -> pure [] - lastMsg SMP.SMPMsgMeta {msgId, msgTs, msgFlags} = case nMeta of - Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'} -> msgId == msgId' || msgTs > msgTs' - Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption) @@ -1073,38 +1064,49 @@ sendMessages' c = sendMessagesB' c . map Right sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = do - connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs + (_, connIds) <- liftEither $ foldl' addConnId (Right ("", S.empty)) reqs lift $ sendMessagesB_ c reqs connIds where - addConnId s@(Right s') (Right (connId, _, _, _)) - | B.null connId = s - | connId `S.notMember` s' = Right $ S.insert connId s' - | otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID" - addConnId s _ = s + addConnId acc@(Right (prevId, s)) (Right (connId, _, _, _)) + | B.null connId = if B.null prevId then Left $ INTERNAL "sendMessages: empty first connId" else acc + | connId `S.member` s = Left $ INTERNAL "sendMessages: duplicate connId" + | otherwise = Right (connId, S.insert connId s) + addConnId acc _ = acc sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + prev <- newTVarIO Nothing + reqs' <- withStoreBatch c $ \db -> fmap (bindRight $ getConn_ db prev) reqs let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' - void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable + void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) $ S.toList toEnable enqueueMessagesB c reqs'' where - prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) - prepareConn acc (Left e) = (acc, Left e) - prepareConn acc (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of + getConn_ :: DB.Connection -> TVar (Maybe (Either AgentErrorType SomeConn)) -> MsgReq -> IO (Either AgentErrorType (MsgReq, SomeConn)) + getConn_ db prev req@(connId, _, _, _) = + (req,) <$$> + if B.null connId + then fromMaybe (Left $ INTERNAL "sendMessagesB_: empty prev connId") <$> readTVarIO prev + else do + conn <- first storeError <$> getConn db connId + conn <$ atomically (writeTVar prev $ Just conn) + prepareConn :: Set ConnId -> Either AgentErrorType (MsgReq, SomeConn) -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareConn s (Left e) = (s, Left e) + prepareConn s (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] - _ -> (acc, Left $ CONN SIMPLEX) + _ -> (s, Left $ CONN SIMPLEX) where - prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareMsg :: ConnData -> NonEmpty SndQueue -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) prepareMsg cData@ConnData {connId, pqSupport} sqs - | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") + | ratchetSyncSendProhibited cData = (s, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. | pqEnc == PQEncOn && pqSupport == PQSupportOff = let cData' = cData {pqSupport = PQSupportOn} :: ConnData - in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) - | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) + in (S.insert connId s, mkReq cData') + | otherwise = (s, mkReq cData) + where + mkReq cData' = Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg) -- / async command processing v v v @@ -1117,13 +1119,10 @@ resumeSrvCmds :: AgentClient -> Maybe SMPServer -> AM' () resumeSrvCmds = void .: getAsyncCmdWorker False {-# INLINE resumeSrvCmds #-} -resumeConnCmds :: AgentClient -> ConnId -> AM () -resumeConnCmds c connId = - unlessM connQueued $ - withStore' c (`getPendingCommandServers` connId) - >>= mapM_ (lift . resumeSrvCmds c) - where - connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c) +resumeConnCmds :: AgentClient -> [ConnId] -> AM' () +resumeConnCmds c connIds = do + srvs <- nubOrd . concat . rights <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) + mapM_ (resumeSrvCmds c) srvs getAsyncCmdWorker :: Bool -> AgentClient -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c server = @@ -1135,7 +1134,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do forever $ do atomically $ endAgentOperation c AOSndNetwork lift $ waitForWork doWork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (`getPendingServerCommand` server_) $ runProcessCmd (riFast ri) where @@ -1155,8 +1154,8 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do - joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv - notify OK + sqSecured <- joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv + notify $ JOINED sqSecured LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK SWCH -> @@ -1252,7 +1251,9 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do withStore c (`getConn` connId) >>= \case SomeConn _ conn@DuplexConnection {} -> a conn _ -> internalErr "command requires duplex connection" - tryCommand action = withRetryInterval ri $ \_ loop -> + tryCommand action = withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c + liftIO $ waitForUserNetwork c tryError action >>= \case Left e | temporaryOrHostError e -> retrySndOp c loop @@ -1360,8 +1361,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI forever $ do atomically $ endAgentOperation c AOSndNetwork lift $ waitForWork doWork - atomically $ throwWhenInactive c - atomically $ throwWhenNoDelivery c sq + liftIO $ throwWhenInactive c + liftIO $ throwWhenNoDelivery c sq atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (\db -> getPendingQueueMsg db connId sq) $ \(rq_, PendingMsgData {msgId, msgType, msgBody, pqEncryption, msgFlags, msgRetryState, internalTs}) -> do @@ -1369,6 +1370,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI let mId = unId msgId ri' = maybe id updateRetryInterval2 msgRetryState ri withRetryLock2 ri' qLock $ \riState loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c resp <- tryError $ case msgType of AM_CONN_INFO -> sendConfirmation c sq msgBody @@ -1521,7 +1523,7 @@ retrySndOp :: AgentClient -> AM () -> AM () retrySndOp c loop = do -- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent atomically $ endAgentOperation c AOSndNetwork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork loop @@ -2026,7 +2028,7 @@ deleteNtfSubs c deleteCmd = do sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM () sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor - connIds <- atomically $ getSubscriptions c + connIds <- liftIO $ getSubscriptions c forM_ connIds $ \connId -> do withStore' c (`getConnData` connId) >>= \case Just (ConnData {enableNtfs}, _) -> @@ -2108,7 +2110,7 @@ cleanupManager c@AgentClient {subQ} = do liftIO $ threadDelay' delay int <- asks (cleanupInterval . config) ttl <- asks $ storedMsgDataTTL . config - forever $ do + forever $ waitActive $ do run ERR deleteConns run ERR $ withStore' c (`deleteRcvMsgHashesExpired` ttl) run ERR $ withStore' c (`deleteSndMsgsExpired` ttl) @@ -2128,7 +2130,8 @@ cleanupManager c@AgentClient {subQ} = do step <- asks $ cleanupStepInterval . config liftIO $ threadDelay step -- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active - waitActive a = liftIO (E.tryAny . atomically $ waitUntilActive c) >>= either (\_ -> pure ()) (\_ -> void a) + waitActive :: ReaderT Env IO a -> AM' () + waitActive a = liftIO (E.tryAny $ waitUntilActive c) >>= either (\_ -> pure ()) (\_ -> void a) deleteConns = withLock (deleteLock c) "cleanupManager" $ do void $ withStore' c getDeletedConnIds >>= deleteDeletedConns c @@ -2218,7 +2221,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId processSubOk :: RcvQueue -> TVar [ConnId] -> AM () processSubOk rq@RcvQueue {connId} upConnIds = atomically . whenM (isPendingSub connId) $ do - addSubscription c rq + addSubscription c sessId rq modifyTVar' upConnIds (connId :) processSubErr :: RcvQueue -> SMPClientError -> AM () processSubErr rq@RcvQueue {connId} e = do @@ -2253,7 +2256,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId ack' <- handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack - whenM (atomically $ hasGetLock c rq) $ + whenM (liftIO $ hasGetLock c rq) $ notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') pure ack' where @@ -2492,6 +2495,18 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId confId <- withStore c $ \db -> do setConnAgentVersion db connId agentVersion when (pqSupport /= pqSupport') $ setConnPQSupport db connId pqSupport' + -- / + -- Starting with agent version 7 (ratchetOnConfSMPAgentVersion), + -- initiating party initializes ratchet on processing confirmation; + -- previously, it initialized ratchet on allowConnection; + -- this is to support decryption of messages that may be received before allowConnection + liftIO $ do + createRatchet db connId rc' + let RcvQueue {smpClientVersion = v, e2ePrivKey = e2ePrivKey'} = rq + SMPConfirmation {smpClientVersion = v', e2ePubKey = e2ePubKey'} = senderConf + dhSecret = C.dh' e2ePubKey' e2ePrivKey' + setRcvQueueConfirmedE2E db rq dhSecret $ min v v' + -- / createConfirmation db g newConfirmation let srvs = map qServer $ smpReplyQueues senderConf notify $ CONF confId pqSupport' srvs connInfo @@ -2775,25 +2790,27 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo sq_ (qInfo :| _ Just qInfo' -> do -- in case of SKEY retry the connection is already duplex sq' <- maybe upgradeConn pure sq_ - agentSecureSndQueue c sq' + void $ agentSecureSndQueue c cData sq' enqueueConfirmation c cData sq' ownConnInfo Nothing where upgradeConn = do (sq, _) <- lift $ newSndQueue userId connId qInfo' withStore c $ \db -> upgradeRcvConnToDuplex db connId sq -secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () +secureConfirmQueueAsync :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do - agentSecureSndQueue c sq + sqSecured <- agentSecureSndQueue c cData sq storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode lift $ submitPendingMsg c cData sq + pure sqSecured -secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () +secureConfirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM SndQueueSecured secureConfirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv connInfo e2eEncryption_ subMode = do - agentSecureSndQueue c sq + sqSecured <- agentSecureSndQueue c cData sq msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode void $ sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed + pure sqSecured where mkConfirmation :: AgentMessage -> AM MsgBody mkConfirmation aMessage = do @@ -2806,11 +2823,17 @@ secureConfirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv (encConnInfo, _) <- agentRatchetEncrypt db cData agentMsgBody e2eEncConnInfoLength (Just pqEnc) currentE2EVersion pure . smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} -agentSecureSndQueue :: AgentClient -> SndQueue -> AM () -agentSecureSndQueue c sq@SndQueue {sndSecure, status} = - when (sndSecure && status == New) $ do - secureSndQueue c sq - withStore' c $ \db -> setSndQueueStatus db sq Secured +agentSecureSndQueue :: AgentClient -> ConnData -> SndQueue -> AM SndQueueSecured +agentSecureSndQueue c ConnData {connAgentVersion} sq@SndQueue {sndSecure, status} + | sndSecure && status == New = do + secureSndQueue c sq + withStore' c $ \db -> setSndQueueStatus db sq Secured + pure initiatorRatchetOnConf + -- on repeat JOIN processing (e.g. previous attempt to create reply queue failed) + | sndSecure && status == Secured = pure initiatorRatchetOnConf + | otherwise = pure False + where + initiatorRatchetOnConf = connAgentVersion >= ratchetOnConfSMPAgentVersion mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 43b3b8064..02b31cb95 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -93,6 +93,7 @@ module Simplex.Messaging.Agent.Client AgentServersSummary (..), ServerSessions (..), SMPServerSubs (..), + getAgentSubsTotal, getAgentServersSummary, getAgentSubscriptions, slowNetworkConfig, @@ -117,7 +118,7 @@ module Simplex.Messaging.Agent.Client waitUntilActive, UserNetworkInfo (..), UserNetworkType (..), - getNetworkConfig', + getFastNetworkConfig, waitForUserNetwork, isNetworkOnline, isOnline, @@ -126,6 +127,7 @@ module Simplex.Messaging.Agent.Client beginAgentOperation, endAgentOperation, waitUntilForeground, + waitWhileSuspended, suspendSendingAndDatabase, suspendOperation, notifySuspended, @@ -145,6 +147,7 @@ module Simplex.Messaging.Agent.Client incXFTPServerStat, incXFTPServerStat', incXFTPServerSizeStat, + incNtfServerStat, AgentWorkersDetails (..), getAgentWorkersDetails, AgentWorkersSummary (..), @@ -162,7 +165,7 @@ where import Control.Applicative ((<|>)) import Control.Concurrent (ThreadId, forkIO) import Control.Concurrent.Async (Async, uninterruptibleCancel) -import Control.Concurrent.STM (retry, throwSTM) +import Control.Concurrent.STM (retry) import Control.Exception (AsyncException (..), BlockedIndefinitelyOnSTM (..)) import Control.Logger.Simple import Control.Monad @@ -304,13 +307,12 @@ data AgentClient = AgentClient userNetworkInfo :: TVar UserNetworkInfo, userNetworkUpdated :: TVar (Maybe UTCTime), subscrConns :: TVar (Set ConnId), - activeSubs :: TRcvQueues, - pendingSubs :: TRcvQueues, + activeSubs :: TRcvQueues (SessionId, RcvQueue), + pendingSubs :: TRcvQueues RcvQueue, removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError, workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), asyncCmdWorkers :: TMap (Maybe SMPServer) Worker, - connCmdsQueued :: TMap ConnId Bool, ntfNetworkOp :: TVar AgentOpState, rcvNetworkOp :: TVar AgentOpState, msgDeliveryOp :: TVar AgentOpState, @@ -330,6 +332,7 @@ data AgentClient = AgentClient agentEnv :: Env, smpServersStats :: TMap (UserId, SMPServer) AgentSMPServerStats, xftpServersStats :: TMap (UserId, XFTPServer) AgentXFTPServerStats, + ntfServersStats :: TMap (UserId, NtfServer) AgentNtfServerStats, srvStatsStartedAt :: TVar UTCTime } @@ -368,13 +371,15 @@ getAgentWorker' toW fromW name hasWork c key ws work = do restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts) when restart runWork shouldRestart e_ Worker {workerId = wId, doWork, action, restarts} t maxRestarts w' - | wId == workerId (toW w') = - checkRestarts . updateRestartCount t =<< readTVar restarts + | wId == workerId (toW w') = do + rc <- readTVar restarts + isActive <- readTVar $ active c + checkRestarts isActive $ updateRestartCount t rc | otherwise = pure False -- there is a new worker in the map, no action where - checkRestarts rc - | restartCount rc < maxRestarts = do + checkRestarts isActive rc + | isActive && restartCount rc < maxRestarts = do writeTVar restarts rc hasWorkToDo' doWork void $ tryPutTMVar action Nothing @@ -382,7 +387,7 @@ getAgentWorker' toW fromW name hasWork c key ws work = do pure True | otherwise = do TM.delete key ws - notifyErr $ CRITICAL True + when isActive $ notifyErr $ CRITICAL True pure False where notifyErr err = do @@ -449,46 +454,46 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther deriving (Eq, Show) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. -newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Env -> STM AgentClient +newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Env -> IO AgentClient newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs agentEnv = do let cfg = config agentEnv qSize = tbqSize cfg - acThread <- newTVar Nothing - active <- newTVar True - subQ <- newTBQueue qSize - msgQ <- newTBQueue qSize - smpServers <- newTVar $ M.map mkUserServers smp - smpClients <- TM.empty - smpProxiedRelays <- TM.empty - ntfServers <- newTVar ntf - ntfClients <- TM.empty - xftpServers <- newTVar $ M.map mkUserServers xftp - xftpClients <- TM.empty - useNetworkConfig <- newTVar (slowNetworkConfig netCfg, netCfg) - userNetworkInfo <- newTVar $ UserNetworkInfo UNOther True - userNetworkUpdated <- newTVar Nothing - subscrConns <- newTVar S.empty + acThread <- newTVarIO Nothing + active <- newTVarIO True + subQ <- newTBQueueIO qSize + msgQ <- newTBQueueIO qSize + smpServers <- newTVarIO $ M.map mkUserServers smp + smpClients <- TM.emptyIO + smpProxiedRelays <- TM.emptyIO + ntfServers <- newTVarIO ntf + ntfClients <- TM.emptyIO + xftpServers <- newTVarIO $ M.map mkUserServers xftp + xftpClients <- TM.emptyIO + useNetworkConfig <- newTVarIO (slowNetworkConfig netCfg, netCfg) + userNetworkInfo <- newTVarIO $ UserNetworkInfo UNOther True + userNetworkUpdated <- newTVarIO Nothing + subscrConns <- newTVarIO S.empty activeSubs <- RQ.empty pendingSubs <- RQ.empty - removedSubs <- TM.empty - workerSeq <- newTVar 0 - smpDeliveryWorkers <- TM.empty - asyncCmdWorkers <- TM.empty - connCmdsQueued <- TM.empty - ntfNetworkOp <- newTVar $ AgentOpState False 0 - rcvNetworkOp <- newTVar $ AgentOpState False 0 - msgDeliveryOp <- newTVar $ AgentOpState False 0 - sndNetworkOp <- newTVar $ AgentOpState False 0 - databaseOp <- newTVar $ AgentOpState False 0 - agentState <- newTVar ASForeground - getMsgLocks <- TM.empty - connLocks <- TM.empty - invLocks <- TM.empty - deleteLock <- createLock - smpSubWorkers <- TM.empty - smpServersStats <- TM.empty - xftpServersStats <- TM.empty - srvStatsStartedAt <- newTVar currentTs + removedSubs <- TM.emptyIO + workerSeq <- newTVarIO 0 + smpDeliveryWorkers <- TM.emptyIO + asyncCmdWorkers <- TM.emptyIO + ntfNetworkOp <- newTVarIO $ AgentOpState False 0 + rcvNetworkOp <- newTVarIO $ AgentOpState False 0 + msgDeliveryOp <- newTVarIO $ AgentOpState False 0 + sndNetworkOp <- newTVarIO $ AgentOpState False 0 + databaseOp <- newTVarIO $ AgentOpState False 0 + agentState <- newTVarIO ASForeground + getMsgLocks <- TM.emptyIO + connLocks <- TM.emptyIO + invLocks <- TM.emptyIO + deleteLock <- atomically createLock + smpSubWorkers <- TM.emptyIO + smpServersStats <- TM.emptyIO + xftpServersStats <- TM.emptyIO + ntfServersStats <- TM.emptyIO + srvStatsStartedAt <- newTVarIO currentTs return AgentClient { acThread, @@ -512,7 +517,6 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs a workerSeq, smpDeliveryWorkers, asyncCmdWorkers, - connCmdsQueued, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, @@ -528,6 +532,7 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg} currentTs a agentEnv, smpServersStats, xftpServersStats, + ntfServersStats, srvStatsStartedAt } @@ -594,7 +599,7 @@ getSMPServerClient c@AgentClient {active, smpClients, workerSeq} tSess = do >>= either newClient (waitForProtocolClient c tSess smpClients) where newClient v = do - prs <- atomically TM.empty + prs <- liftIO TM.emptyIO smpConnectClient c tSess prs v getSMPProxyClient :: AgentClient -> Maybe SMPServerWithAuth -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) @@ -612,11 +617,10 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq (tSess,auth,) <$> getSessVar workerSeq tSess smpClients ts newProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> UTCTime -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) newProxyClient tSess auth ts v = do - (prs, rv) <- atomically $ do - prs <- TM.empty - -- we do not need to check if it is a new proxied relay session, - -- as the client is just created and there are no sessions yet - (prs,) . either id id <$> getSessVar workerSeq destSrv prs ts + prs <- liftIO TM.emptyIO + -- we do not need to check if it is a new proxied relay session, + -- as the client is just created and there are no sessions yet + rv <- atomically $ either id id <$> getSessVar workerSeq destSrv prs ts clnt <- smpConnectClient c tSess prs v (clnt,) <$> newProxiedRelay clnt auth rv waitForProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) @@ -642,7 +646,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq pure $ Left e waitForProxiedRelay :: SMPTransportSession -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) waitForProxiedRelay (_, srv, _) rv = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c sess_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar rv) pure $ case sess_ of Just (Right sess) -> Right sess @@ -672,11 +676,13 @@ smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess -- because we can have a race condition when a new current client could have already -- made subscriptions active, and the old client would be processing diconnection later. removeClientAndSubs :: IO ([RcvQueue], [ConnId]) - removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], []) + removeClientAndSubs = atomically $ do + removeSessVar v tSess smpClients + ifM (readTVar active) removeSubs (pure ([], [])) where - currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active + sessId = sessionId $ thParams client removeSubs = do - (qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c + (qs, cs) <- RQ.getDelSessQueues tSess sessId $ activeSubs c RQ.batchAddQueues (pendingSubs c) qs -- this removes proxied relays that this client created sessions to destSrvs <- M.keys <$> readTVar prs @@ -701,7 +707,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do where getWorkerVar ts = ifM - (null <$> getPending) + (not <$> RQ.hasSessQueues tSess (pendingSubs c)) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq tSess smpSubWorkers ts) newSubWorker v = do @@ -709,13 +715,14 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do atomically $ putTMVar (sessionVar v) a runSubWorker = do ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \_ loop -> do - pending <- atomically getPending + withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do + pending <- liftIO $ RQ.getSessQueues tSess $ pendingSubs c forM_ (L.nonEmpty pending) $ \qs -> do + liftIO $ waitUntilForeground c liftIO $ waitForUserNetwork c reconnectSMPClient c tSess qs loop - getPending = RQ.getSessQueues tSess $ pendingSubs c + isForeground = (ASForeground ==) <$> readTVar (agentState c) cleanup :: SessionVar (Async ()) -> STM () cleanup v = do -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. @@ -780,7 +787,7 @@ getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(_, srv connectClient :: XFTPClientVar -> AM XFTPClient connectClient v = do cfg <- asks $ xftpCfg . config - xftpNetworkConfig <- atomically $ getNetworkConfig c + xftpNetworkConfig <- getNetworkConfig c liftError' (protocolClientError XFTP $ B.unpack $ strEncode srv) $ X.getXFTPClient tSess cfg {xftpNetworkConfig} $ clientDisconnected v @@ -799,7 +806,7 @@ waitForProtocolClient :: ClientVar msg -> AM (Client msg) waitForProtocolClient c tSess@(_, srv, _) clients v = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) case client_ of Just (Right smpClient) -> pure smpClient @@ -850,26 +857,26 @@ hostEvent' event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clie getClientConfig :: AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> AM' (ProtocolClientConfig v) getClientConfig c cfgSel = do cfg <- asks $ cfgSel . config - networkConfig <- atomically $ getNetworkConfig c + networkConfig <- getNetworkConfig c pure cfg {networkConfig} -getNetworkConfig :: AgentClient -> STM NetworkConfig +getNetworkConfig :: MonadIO m => AgentClient -> m NetworkConfig getNetworkConfig c = do - (slowCfg, fastCfg) <- readTVar (useNetworkConfig c) - UserNetworkInfo {networkType} <- readTVar $ userNetworkInfo c + (slowCfg, fastCfg) <- readTVarIO $ useNetworkConfig c + UserNetworkInfo {networkType} <- readTVarIO $ userNetworkInfo c pure $ case networkType of UNCellular -> slowCfg UNNone -> slowCfg _ -> fastCfg -- returns fast network config -getNetworkConfig' :: AgentClient -> IO NetworkConfig -getNetworkConfig' = fmap snd . readTVarIO . useNetworkConfig -{-# INLINE getNetworkConfig' #-} +getFastNetworkConfig :: AgentClient -> IO NetworkConfig +getFastNetworkConfig = fmap snd . readTVarIO . useNetworkConfig +{-# INLINE getFastNetworkConfig #-} waitForUserNetwork :: AgentClient -> IO () waitForUserNetwork c = - unlessM (atomically $ isNetworkOnline c) $ do + unlessM (isOnline <$> readTVarIO (userNetworkInfo c)) $ do delay <- registerDelay $ userNetworkInterval $ config $ agentEnv c atomically $ unlessM (isNetworkOnline c) $ unlessM (readTVar delay) retry @@ -883,7 +890,6 @@ closeAgentClient c = do atomically (swapTVar (smpSubWorkers c) M.empty) >>= mapM_ cancelReconnect clearWorkers smpDeliveryWorkers >>= mapM_ (cancelWorker . fst) clearWorkers asyncCmdWorkers >>= mapM_ cancelWorker - clear connCmdsQueued atomically . RQ.clear $ activeSubs c atomically . RQ.clear $ pendingSubs c clear subscrConns @@ -901,19 +907,18 @@ cancelWorker Worker {doWork, action} = do noWorkToDo doWork atomically (tryTakeTMVar action) >>= mapM_ (mapM_ uninterruptibleCancel) -waitUntilActive :: AgentClient -> STM () -waitUntilActive c = unlessM (readTVar $ active c) retry -{-# INLINE waitUntilActive #-} +waitUntilActive :: AgentClient -> IO () +waitUntilActive AgentClient {active} = unlessM (readTVarIO active) $ atomically $ unlessM (readTVar active) retry -throwWhenInactive :: AgentClient -> STM () -throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled +throwWhenInactive :: AgentClient -> IO () +throwWhenInactive c = unlessM (readTVarIO $ active c) $ E.throwIO ThreadKilled {-# INLINE throwWhenInactive #-} -- this function is used to remove workers once delivery is complete, not when it is removed from the map -throwWhenNoDelivery :: AgentClient -> SndQueue -> STM () +throwWhenNoDelivery :: AgentClient -> SndQueue -> IO () throwWhenNoDelivery c sq = - unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $ - throwSTM ThreadKilled + unlessM (TM.memberIO (qAddress sq) $ smpDeliveryWorkers c) $ + E.throwIO ThreadKilled closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = @@ -939,7 +944,7 @@ closeClient c clientSel tSess = closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do - NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c + NetworkConfig {tcpConnectTimeout} <- getNetworkConfig c E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case Just (Right client) -> closeProtocolServerClient (protocolClient client) `catchAll_` pure () @@ -1027,7 +1032,7 @@ withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr withSMPClient :: SMPQueueRec q => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a withSMPClient c q cmdStr action = do - tSess <- liftIO $ mkSMPTransportSession c q + tSess <- mkSMPTransportSession c q withLogClient c tSess (queueId q) cmdStr $ action . connectedClient sendOrProxySMPMessage :: AgentClient -> UserId -> SMPServer -> ByteString -> Maybe SMP.SndPrivateAuthKey -> SMP.SenderId -> MsgFlags -> SMP.MsgBody -> AM (Maybe SMPServer) @@ -1052,8 +1057,8 @@ sendOrProxySMPCommand :: (SMPClient -> ExceptT SMPClientError IO ()) -> AM (Maybe SMPServer) sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDirectly = do - sess <- liftIO $ mkTransportSession c userId destSrv senderId - ifM (atomically shouldUseProxy) (sendViaProxy Nothing sess) (sendDirectly sess $> Nothing) + sess <- mkTransportSession c userId destSrv senderId + ifM shouldUseProxy (sendViaProxy Nothing sess) (sendDirectly sess $> Nothing) where shouldUseProxy = do cfg <- getNetworkConfig c @@ -1070,7 +1075,7 @@ sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDi SPFAllow -> True SPFAllowProtected -> ipAddressProtected cfg destSrv SPFProhibit -> False - unknownServer = maybe True (notElem destSrv . knownSrvs) <$> TM.lookup userId (smpServers c) + unknownServer = liftIO $ maybe True (notElem destSrv . knownSrvs) <$> TM.lookupIO userId (smpServers c) sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer) sendViaProxy proxySrv_ destSess@(_, _, qId) = do r <- tryAgentError . withProxySession c proxySrv_ destSess senderId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do @@ -1116,7 +1121,7 @@ sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDi forM_ r' $ \proxySrv -> atomically $ incSMPServerStat c userId proxySrv sentProxied pure r' Left e - | serverHostError e -> ifM (atomically directAllowed) (sendDirectly destSess $> Nothing) (throwE e) + | serverHostError e -> ifM directAllowed (sendDirectly destSess $> Nothing) (throwE e) | otherwise -> throwE e sendDirectly tSess = withLogClient_ c tSess senderId ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do @@ -1142,7 +1147,7 @@ withXFTPClient :: (Client msg -> ExceptT (ProtocolClientError err) IO b) -> AM b withXFTPClient c (userId, srv, entityId) cmdStr action = do - tSess <- liftIO $ mkTransportSession c userId srv entityId + tSess <- mkTransportSession c userId srv entityId withLogClient c tSess entityId cmdStr action liftClient :: (Show err, Encoding err) => (HostName -> err -> AgentErrorType) -> HostName -> ExceptT (ProtocolClientError err) IO a -> AM a @@ -1214,7 +1219,7 @@ runXFTPServerTest :: AgentClient -> UserId -> XFTPServerWithAuth -> AM' (Maybe P runXFTPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- asks $ xftpCfg . config g <- asks random - xftpNetworkConfig <- atomically $ getNetworkConfig c + xftpNetworkConfig <- getNetworkConfig c workDir <- getXFTPWorkPath filePath <- getTempFilePath workDir rcvPath <- getTempFilePath workDir @@ -1285,7 +1290,7 @@ getXFTPWorkPath = do workDir <- readTVarIO =<< asks (xftpWorkDir . xftpAgent) maybe getTemporaryDirectory pure workDir -mkTransportSession :: AgentClient -> UserId -> ProtoServer msg -> EntityId -> IO (TransportSession msg) +mkTransportSession :: MonadIO m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c {-# INLINE mkTransportSession #-} @@ -1293,7 +1298,7 @@ mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> T mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing) {-# INLINE mkTSession #-} -mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> IO SMPTransportSession +mkSMPTransportSession :: (SMPQueueRec q, MonadIO m) => AgentClient -> q -> m SMPTransportSession mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c {-# INLINE mkSMPTransportSession #-} @@ -1301,8 +1306,8 @@ mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSessi mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) {-# INLINE mkSMPTSession #-} -getSessionMode :: AgentClient -> IO TransportSessionMode -getSessionMode = atomically . fmap sessionMode . getNetworkConfig +getSessionMode :: MonadIO m => AgentClient -> m TransportSessionMode +getSessionMode = fmap sessionMode . getNetworkConfig {-# INLINE getSessionMode #-} newRcvQueue :: AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> SenderCanSecure -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) @@ -1313,7 +1318,7 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender (dhKey, privDhKey) <- atomically $ C.generateKeyPair g (e2eDhKey, e2ePrivKey) <- atomically $ C.generateKeyPair g logServer "-->" c srv "" "NEW" - tSess <- liftIO $ mkTransportSession c userId srv connId + tSess <- mkTransportSession c userId srv connId (sessId, QIK {rcvId, sndId, rcvPublicDhKey, sndSecure}) <- withClient c tSess $ \(SMPConnectedClient smp _) -> (sessionId $ thParams smp,) <$> createSMPQueue smp rKeys dhKey auth subMode senderCanSecure @@ -1342,8 +1347,8 @@ newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode sender qUri = SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey sndSecure pure (rq, qUri, tSess, sessId) -processSubResult :: AgentClient -> RcvQueue -> Either SMPClientError () -> STM () -processSubResult c rq@RcvQueue {userId, server, connId} = \case +processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError () -> STM () +processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case Left e -> unless (temporaryClientError e) $ do incSMPServerStat c userId server connSubErrs @@ -1351,7 +1356,7 @@ processSubResult c rq@RcvQueue {userId, server, connId} = \case Right () -> ifM (hasPendingSubscription c connId) - (incSMPServerStat c userId server connSubscribed >> addSubscription c rq) + (incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq) (incSMPServerStat c userId server connSubIgnored) temporaryAgentError :: AgentErrorType -> Bool @@ -1399,7 +1404,7 @@ subscribeQueues c qs = do (errs <> rs,) <$> readTVarIO session where checkQueue rq = do - prohibited <- atomically $ hasGetLock c rq + prohibited <- liftIO $ hasGetLock c rq pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ()) subscribeQueues_ env session smp qs' = do @@ -1422,7 +1427,7 @@ subscribeQueues c qs = do sessId = sessionId $ thParams smp hasTempErrors = any (either temporaryClientError (const False) . snd) processSubResults :: NonEmpty (RcvQueue, Either SMPClientError ()) -> STM () - processSubResults = mapM_ $ uncurry $ processSubResult c + processSubResults = mapM_ $ uncurry $ processSubResult c sessId resubscribe = resubscribeSMPSession c tSess `runReaderT` env activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool @@ -1440,7 +1445,7 @@ sendTSessionBatches statCmd toRQ action c qs = where batchQueues :: AM' [(SMPTransportSession, NonEmpty q)] batchQueues = do - mode <- atomically $ sessionMode <$> getNetworkConfig c + mode <- getSessionMode c pure . M.assocs $ foldl' (batch mode) M.empty qs where batch mode m q = @@ -1461,10 +1466,10 @@ sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs) where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: AgentClient -> RcvQueue -> STM () -addSubscription c rq@RcvQueue {connId} = do +addSubscription :: AgentClient -> SessionId -> RcvQueue -> STM () +addSubscription c sessId rq@RcvQueue {connId} = do modifyTVar' (subscrConns c) $ S.insert connId - RQ.addQueue rq $ activeSubs c + RQ.addQueue (sessId, rq) $ activeSubs c RQ.deleteQueue rq $ pendingSubs c failSubscription :: AgentClient -> RcvQueue -> SMPClientError -> STM () @@ -1483,7 +1488,7 @@ addNewQueueSubscription c rq tSess sessId = do atomically $ ifM (activeClientSession c tSess sessId) - (True <$ addSubscription c rq) + (True <$ addSubscription c sessId rq) (False <$ addPendingSubscription c rq) unless same $ resubscribeSMPSession c tSess @@ -1501,8 +1506,8 @@ removeSubscription c connId = do RQ.deleteConn connId $ activeSubs c RQ.deleteConn connId $ pendingSubs c -getSubscriptions :: AgentClient -> STM (Set ConnId) -getSubscriptions = readTVar . subscrConns +getSubscriptions :: AgentClient -> IO (Set ConnId) +getSubscriptions = readTVarIO . subscrConns {-# INLINE getSubscriptions #-} logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m () @@ -1601,9 +1606,9 @@ sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do ackSMPMessage smp rcvPrivateKey rcvId msgId atomically $ releaseGetLock c rq -hasGetLock :: AgentClient -> RcvQueue -> STM Bool +hasGetLock :: AgentClient -> RcvQueue -> IO Bool hasGetLock c RcvQueue {server, rcvId} = - TM.member (server, rcvId) $ getMsgLocks c + TM.memberIO (server, rcvId) $ getMsgLocks c releaseGetLock :: AgentClient -> RcvQueue -> STM () releaseGetLock c RcvQueue {server, rcvId} = @@ -1702,7 +1707,7 @@ agentXFTPNewChunk c SndFileChunk {userId, chunkSpec = XFTPChunkSpec {chunkSize}, (sndKey, replicaKey) <- atomically . C.generateAuthKeyPair C.SEd25519 =<< asks random let fileInfo = FileInfo {sndKey, size = chunkSize, digest = chunkDigest} logServer "-->" c srv "" "FNEW" - tSess <- liftIO $ mkTransportSession c userId srv chunkDigest + tSess <- mkTransportSession c userId srv chunkDigest (sndId, rIds) <- withClient c tSess $ \xftp -> X.createXFTPChunk xftp replicaKey fileInfo (L.map fst rKeys) auth logServer "<--" c srv "" $ B.unwords ["SIDS", logSecret sndId] pure NewSndChunkReplica {server = srv, replicaId = ChunkReplicaId sndId, replicaKey, rcvIdsKeys = L.toList $ xftpRcvIdsKeys rIds rKeys} @@ -1855,16 +1860,28 @@ beginAgentOperation c op = do -- unsafeIOToSTM $ putStrLn $ "beginOperation! " <> show op <> " " <> show (opsInProgress s + 1) writeTVar opVar $! s {opsInProgress = opsInProgress s + 1} -agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> (AgentClient -> STM ()) -> m a -> m a +agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> (AgentClient -> IO ()) -> m a -> m a agentOperationBracket c op check action = E.bracket - (atomically (check c) >> atomically (beginAgentOperation c op)) + (liftIO (check c) >> atomically (beginAgentOperation c op)) (\_ -> atomically $ endAgentOperation c op) (const action) -waitUntilForeground :: AgentClient -> STM () -waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry -{-# INLINE waitUntilForeground #-} +waitUntilForeground :: AgentClient -> IO () +waitUntilForeground c = + unlessM (foreground readTVarIO) $ atomically $ unlessM (foreground readTVar) retry + where + foreground :: Monad m => (TVar AgentState -> m AgentState) -> m Bool + foreground rd = (ASForeground ==) <$> rd (agentState c) + +-- This function waits while agent is suspended, but will proceed while it is suspending, +-- to allow completing in-flight operations. +waitWhileSuspended :: AgentClient -> IO () +waitWhileSuspended c = + whenM (suspended readTVarIO) $ atomically $ whenM (suspended readTVar) retry + where + suspended :: Monad m => (TVar AgentState -> m AgentState) -> m Bool + suspended rd = (ASSuspended ==) <$> rd (agentState c) withStore' :: AgentClient -> (DB.Connection -> IO a) -> AM a withStore' c action = withStore c $ fmap Right . action @@ -1935,7 +1952,7 @@ getNextServer c userId usedSrvs = withUserServers c userId $ \srvs -> withUserServers :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> UserId -> (NonEmpty (ProtoServerWithAuth p) -> AM a) -> AM a withUserServers c userId action = - atomically (TM.lookup userId $ userServers c) >>= \case + liftIO (TM.lookupIO userId $ userServers c) >>= \case Just srvs -> action $ enabledSrvs srvs _ -> throwE $ INTERNAL "unknown userId - no user servers" @@ -1943,24 +1960,17 @@ withNextSrv :: forall p a. (ProtocolTypeI p, UserProtocol p) => AgentClient -> U withNextSrv c userId usedSrvs initUsed action = do used <- readTVarIO usedSrvs srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId used - atomically $ do - srvs_ <- TM.lookup userId $ userServers c - let unused = maybe [] ((\\ used) . map protoServer . L.toList . enabledSrvs) srvs_ - used' = if null unused then initUsed else srv : used - writeTVar usedSrvs $! used' + srvs_ <- liftIO $ TM.lookupIO userId $ userServers c + let unused = maybe [] ((\\ used) . map protoServer . L.toList . enabledSrvs) srvs_ + used' = if null unused then initUsed else srv : used + atomically $ writeTVar usedSrvs $! used' action srvAuth incSMPServerStat :: AgentClient -> UserId -> SMPServer -> (AgentSMPServerStats -> TVar Int) -> STM () incSMPServerStat c userId srv sel = incSMPServerStat' c userId srv sel 1 incSMPServerStat' :: AgentClient -> UserId -> SMPServer -> (AgentSMPServerStats -> TVar Int) -> Int -> STM () -incSMPServerStat' AgentClient {smpServersStats} userId srv sel n = do - TM.lookup (userId, srv) smpServersStats >>= \case - Just v -> modifyTVar' (sel v) (+ n) - Nothing -> do - newStats <- newAgentSMPServerStats - modifyTVar' (sel newStats) (+ n) - TM.insert (userId, srv) newStats smpServersStats +incSMPServerStat' = incServerStat (\AgentClient {smpServersStats = s} -> s) newAgentSMPServerStats incXFTPServerStat :: AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar Int) -> STM () incXFTPServerStat c userId srv sel = incXFTPServerStat_ c userId srv sel 1 @@ -1975,24 +1985,34 @@ incXFTPServerSizeStat = incXFTPServerStat_ {-# INLINE incXFTPServerSizeStat #-} incXFTPServerStat_ :: Num n => AgentClient -> UserId -> XFTPServer -> (AgentXFTPServerStats -> TVar n) -> n -> STM () -incXFTPServerStat_ AgentClient {xftpServersStats} userId srv sel n = do - TM.lookup (userId, srv) xftpServersStats >>= \case +incXFTPServerStat_ = incServerStat (\AgentClient {xftpServersStats = s} -> s) newAgentXFTPServerStats +{-# INLINE incXFTPServerStat_ #-} + +incNtfServerStat :: AgentClient -> UserId -> NtfServer -> (AgentNtfServerStats -> TVar Int) -> STM () +incNtfServerStat c userId srv sel = incServerStat (\AgentClient {ntfServersStats = s} -> s) newAgentNtfServerStats c userId srv sel 1 +{-# INLINE incNtfServerStat #-} + +incServerStat :: Num n => (AgentClient -> TMap (UserId, ProtocolServer p) s) -> STM s -> AgentClient -> UserId -> ProtocolServer p -> (s -> TVar n) -> n -> STM () +incServerStat statsSel mkNewStats c userId srv sel n = do + TM.lookup (userId, srv) (statsSel c) >>= \case Just v -> modifyTVar' (sel v) (+ n) Nothing -> do - newStats <- newAgentXFTPServerStats + newStats <- mkNewStats modifyTVar' (sel newStats) (+ n) - TM.insert (userId, srv) newStats xftpServersStats + TM.insert (userId, srv) newStats (statsSel c) data AgentServersSummary = AgentServersSummary { smpServersStats :: Map (UserId, SMPServer) AgentSMPServerStatsData, xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData, + ntfServersStats :: Map (UserId, NtfServer) AgentNtfServerStatsData, statsStartedAt :: UTCTime, smpServersSessions :: Map (UserId, SMPServer) ServerSessions, smpServersSubs :: Map (UserId, SMPServer) SMPServerSubs, xftpServersSessions :: Map (UserId, XFTPServer) ServerSessions, xftpRcvInProgress :: [XFTPServer], xftpSndInProgress :: [XFTPServer], - xftpDelInProgress :: [XFTPServer] + xftpDelInProgress :: [XFTPServer], + ntfServersSessions :: Map (UserId, NtfServer) ServerSessions } deriving (Show) @@ -2009,10 +2029,30 @@ data ServerSessions = ServerSessions } deriving (Show) +getAgentSubsTotal :: AgentClient -> [UserId] -> IO (SMPServerSubs, Bool) +getAgentSubsTotal c userIds = do + ssActive <- getSubsCount activeSubs + ssPending <- getSubsCount pendingSubs + sess <- hasSession . M.toList =<< readTVarIO (smpClients c) + pure (SMPServerSubs {ssActive, ssPending}, sess) + where + getSubsCount :: (AgentClient -> TRcvQueues q) -> IO Int + getSubsCount subs = M.foldrWithKey' addSub 0 <$> readTVarIO (getRcvQueues $ subs c) + addSub :: (UserId, SMPServer, SMP.RecipientId) -> q -> Int -> Int + addSub (userId, _, _) _ cnt = if userId `elem` userIds then cnt + 1 else cnt + hasSession :: [(SMPTransportSession, SMPClientVar)] -> IO Bool + hasSession = \case + [] -> pure False + (s : ss) -> ifM (isConnected s) (pure True) (hasSession ss) + isConnected ((userId, _, _), SessionVar {sessionVar}) + | userId `elem` userIds = atomically $ maybe False isRight <$> tryReadTMVar sessionVar + | otherwise = pure False + getAgentServersSummary :: AgentClient -> IO AgentServersSummary -getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, srvStatsStartedAt, agentEnv} = do +getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt, agentEnv} = do sss <- mapM getAgentSMPServerStats =<< readTVarIO smpServersStats xss <- mapM getAgentXFTPServerStats =<< readTVarIO xftpServersStats + nss <- mapM getAgentNtfServerStats =<< readTVarIO ntfServersStats statsStartedAt <- readTVarIO srvStatsStartedAt smpServersSessions <- countSessions =<< readTVarIO (smpClients c) smpServersSubs <- getServerSubs @@ -2020,17 +2060,20 @@ getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, srvStat xftpRcvInProgress <- catMaybes <$> getXFTPWorkerSrvs xftpRcvWorkers xftpSndInProgress <- catMaybes <$> getXFTPWorkerSrvs xftpSndWorkers xftpDelInProgress <- getXFTPWorkerSrvs xftpDelWorkers + ntfServersSessions <- countSessions =<< readTVarIO (ntfClients c) pure AgentServersSummary { smpServersStats = sss, xftpServersStats = xss, + ntfServersStats = nss, statsStartedAt, smpServersSessions, smpServersSubs, xftpServersSessions, xftpRcvInProgress, xftpSndInProgress, - xftpDelInProgress + xftpDelInProgress, + ntfServersSessions } where getServerSubs = do @@ -2076,6 +2119,7 @@ getAgentSubscriptions c = do removedSubscriptions <- getRemovedSubs pure $ SubscriptionsInfo {activeSubscriptions, pendingSubscriptions, removedSubscriptions} where + getSubs :: (AgentClient -> TRcvQueues q) -> IO [SubInfo] getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c) getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c) subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 0f88508b9..f57cf91e9 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -51,7 +51,7 @@ import Data.ByteArray (ScrubbedBytes) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L -import Data.Map (Map) +import Data.Map.Strict (Map) import Data.Maybe (fromMaybe) import Data.Time.Clock (NominalDiffTime, nominalDay) import Data.Time.Clock.System (SystemTime (..)) @@ -148,10 +148,7 @@ data AgentConfig = AgentConfig xftpMaxRecipientsPerRequest :: Int, deleteErrorCount :: Int, ntfCron :: Word16, - ntfWorkerDelay :: Int, - ntfSMPWorkerDelay :: Int, ntfSubCheckInterval :: NominalDiffTime, - ntfMaxMessages :: Int, caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, @@ -165,7 +162,7 @@ defaultReconnectInterval = RetryInterval { initialInterval = 2_000000, increaseAfter = 10_000000, - maxInterval = 60_000000 + maxInterval = 180_000000 } defaultMessageRetryInterval :: RetryInterval2 @@ -175,7 +172,7 @@ defaultMessageRetryInterval = RetryInterval { initialInterval = 2_000000, increaseAfter = 10_000000, - maxInterval = 60_000000 + maxInterval = 120_000000 }, riSlow = RetryInterval @@ -220,10 +217,7 @@ defaultAgentConfig = xftpMaxRecipientsPerRequest = 200, deleteErrorCount = 10, ntfCron = 20, -- minutes - ntfWorkerDelay = 100000, -- microseconds - ntfSMPWorkerDelay = 500000, -- microseconds ntfSubCheckInterval = nominalDay, - ntfMaxMessages = 3, -- CA certificate private key is not needed for initialization -- ! we do not generate these caCertificateFile = "/etc/opt/simplex-agent/ca.crt", @@ -248,8 +242,8 @@ newSMPAgentEnv :: AgentConfig -> SQLiteStore -> IO Env newSMPAgentEnv config store = do random <- C.newRandom randomServer <- newTVarIO =<< liftIO newStdGen - ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config - xftpAgent <- atomically newXFTPAgent + ntfSupervisor <- newNtfSubSupervisor $ tbqSize config + xftpAgent <- newXFTPAgent multicastSubscribers <- newTMVarIO 0 pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} @@ -266,12 +260,12 @@ data NtfSupervisor = NtfSupervisor data NtfSupervisorCommand = NSCCreate | NSCDelete | NSCSmpDelete | NSCNtfWorker NtfServer | NSCNtfSMPWorker SMPServer deriving (Show) -newNtfSubSupervisor :: Natural -> STM NtfSupervisor +newNtfSubSupervisor :: Natural -> IO NtfSupervisor newNtfSubSupervisor qSize = do - ntfTkn <- newTVar Nothing - ntfSubQ <- newTBQueue qSize - ntfWorkers <- TM.empty - ntfSMPWorkers <- TM.empty + ntfTkn <- newTVarIO Nothing + ntfSubQ <- newTBQueueIO qSize + ntfWorkers <- TM.emptyIO + ntfSMPWorkers <- TM.emptyIO pure NtfSupervisor {ntfTkn, ntfSubQ, ntfWorkers, ntfSMPWorkers} data XFTPAgent = XFTPAgent @@ -282,12 +276,12 @@ data XFTPAgent = XFTPAgent xftpDelWorkers :: TMap XFTPServer Worker } -newXFTPAgent :: STM XFTPAgent +newXFTPAgent :: IO XFTPAgent newXFTPAgent = do - xftpWorkDir <- newTVar Nothing - xftpRcvWorkers <- TM.empty - xftpSndWorkers <- TM.empty - xftpDelWorkers <- TM.empty + xftpWorkDir <- newTVarIO Nothing + xftpRcvWorkers <- TM.emptyIO + xftpSndWorkers <- TM.emptyIO + xftpDelWorkers <- TM.emptyIO pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} tryAgentError :: AM a -> AM (Either AgentErrorType a) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index a239768b0..23a88ea70 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -20,8 +20,8 @@ where import Control.Logger.Simple (logError, logInfo) import Control.Monad -import Control.Monad.Except import Control.Monad.Reader +import Control.Monad.Trans.Except import Data.Bifunctor (first) import qualified Data.Map.Strict as M import Data.Text (Text) @@ -31,6 +31,7 @@ import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol (AEvent (..), AEvt (..), AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..), SAEntity (..)) import Simplex.Messaging.Agent.RetryInterval +import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Crypto as C @@ -40,7 +41,7 @@ import Simplex.Messaging.Protocol (NtfServer, SMPServer, sameSrvAddr) import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, unlessM) import System.Random (randomR) import UnliftIO -import UnliftIO.Concurrent (forkIO, threadDelay) +import UnliftIO.Concurrent (forkIO) import qualified UnliftIO.Exception as E runNtfSupervisor :: AgentClient -> AM' () @@ -64,7 +65,7 @@ processNtfSub c (connId, cmd) = do logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd case cmd of NSCCreate -> do - (a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do + (a, RcvQueue {userId, server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do a <- liftIO $ getNtfSubscription db connId q <- ExceptT $ getPrimaryRcvQueue db connId pure (a, q) @@ -74,12 +75,12 @@ processNtfSub c (connId, cmd) = do withTokenServer $ \ntfServer -> do case clientNtfCreds of Just ClientNtfCreds {notifierId} -> do - let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey - withStore c $ \db -> createNtfSubscription db newSub $ NtfSubNTFAction NSACreate + let newSub = newNtfSubscription userId connId smpServer (Just notifierId) ntfServer NASKey + withStore c $ \db -> createNtfSubscription db newSub $ NSANtf NSACreate lift . void $ getNtfNTFWorker True c ntfServer Nothing -> do - let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew - withStore c $ \db -> createNtfSubscription db newSub $ NtfSubSMPAction NSASmpKey + let newSub = newNtfSubscription userId connId smpServer Nothing ntfServer NASNew + withStore c $ \db -> createNtfSubscription db newSub $ NSASMP NSASmpKey lift . void $ getNtfSMPWorker True c smpServer (Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer, smpServer = smpServer', ntfQueueId}, action_)) -> do case (clientNtfCreds, ntfQueueId) of @@ -99,24 +100,24 @@ processNtfSub c (connId, cmd) = do if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted then resetSubscription else withTokenServer $ \ntfServer -> do - withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NtfSubNTFAction NSACreate) + withStore' c $ \db -> supervisorUpdateNtfSub db sub {ntfServer} (NSANtf NSACreate) lift . void $ getNtfNTFWorker True c ntfServer | otherwise -> case action of - NtfSubNTFAction _ -> lift . void $ getNtfNTFWorker True c subNtfServer - NtfSubSMPAction _ -> lift . void $ getNtfSMPWorker True c smpServer + NSANtf _ -> lift . void $ getNtfNTFWorker True c subNtfServer + NSASMP _ -> lift . void $ getNtfSMPWorker True c smpServer rotate :: AM () rotate = do - withStore' c $ \db -> supervisorUpdateNtfSub db sub (NtfSubNTFAction NSARotate) + withStore' c $ \db -> supervisorUpdateNtfSub db sub (NSANtf NSARotate) lift . void $ getNtfNTFWorker True c subNtfServer resetSubscription :: AM () resetSubscription = withTokenServer $ \ntfServer -> do let sub' = sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew} - withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NtfSubSMPAction NSASmpKey) + withStore' c $ \db -> supervisorUpdateNtfSub db sub' (NSASMP NSASmpKey) lift . void $ getNtfSMPWorker True c smpServer NSCDelete -> do sub_ <- withStore' c $ \db -> do - supervisorUpdateNtfAction db connId (NtfSubNTFAction NSADelete) + supervisorUpdateNtfAction db connId (NSANtf NSADelete) getNtfSubscription db connId logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_ case sub_ of @@ -126,7 +127,7 @@ processNtfSub c (connId, cmd) = do withStore' c (`getPrimaryRcvQueue` connId) >>= \case Right rq@RcvQueue {server = smpServer} -> do logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq - withStore' c $ \db -> supervisorUpdateNtfAction db connId (NtfSubSMPAction NSASmpDelete) + withStore' c $ \db -> supervisorUpdateNtfAction db connId (NSASMP NSASmpDelete) lift . void $ getNtfSMPWorker True c smpServer _ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue" NSCNtfWorker ntfServer -> lift . void $ getNtfNTFWorker True c ntfServer @@ -146,12 +147,10 @@ withTokenServer :: (NtfServer -> AM ()) -> AM () withTokenServer action = lift getNtfToken >>= mapM_ (\NtfToken {ntfServer} -> action ntfServer) runNtfWorker :: AgentClient -> NtfServer -> Worker -> AM () -runNtfWorker c srv Worker {doWork} = do - delay <- asks $ ntfWorkerDelay . config +runNtfWorker c srv Worker {doWork} = forever $ do waitForWork doWork ExceptT $ agentOperationBracket c AONtfNetwork throwWhenInactive $ runExceptT runNtfOperation - threadDelay delay where runNtfOperation :: AM () runNtfOperation = @@ -160,70 +159,73 @@ runNtfWorker c srv Worker {doWork} = do logInfo $ "runNtfWorker, nextSub " <> tshow nextSub ri <- asks $ reconnectInterval . config withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c processSub nextSub `catchAgentError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show) processSub :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> AM () - processSub (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do + processSub (sub@NtfSubscription {userId, connId, smpServer, ntfSubId}, action, actionTs) = do ts <- liftIO getCurrentTime unlessM (lift $ rescheduleAction doWork ts actionTs) $ case action of NSACreate -> lift getNtfToken >>= \case - Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do + Just tkn@NtfToken {ntfServer, ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId) case clientNtfCreds of Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do + atomically $ incNtfServerStat c userId ntfServer ntfCreateAttempts nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey + atomically $ incNtfServerStat c userId ntfServer ntfCreated -- possible improvement: smaller retry until Active, less frequently (daily?) once Active let actionTs' = addUTCTime 30 ts withStore' c $ \db -> - updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NtfSubNTFAction NSACheck) actionTs' + updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NSANtf NSACheck) actionTs' _ -> workerInternalError c connId "NSACreate - no notifier queue credentials" _ -> workerInternalError c connId "NSACreate - no active token" NSACheck -> lift getNtfToken >>= \case - Just tkn -> + Just tkn@NtfToken {ntfServer} -> case ntfSubId of - Just nSubId -> + Just nSubId -> do + atomically $ incNtfServerStat c userId ntfServer ntfCheckAttempts agentNtfCheckSubscription c nSubId tkn >>= \case NSAuth -> do - lift (getNtfServer c) >>= \case - Just ntfServer -> do - withStore' c $ \db -> - updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) - _ -> workerInternalError c connId "NSACheck - failed to reset subscription, notification server not configured" + withStore' c $ \db -> + updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NSASMP NSASmpKey) ts + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) status -> updateSubNextCheck ts status + atomically $ incNtfServerStat c userId ntfServer ntfChecked Nothing -> workerInternalError c connId "NSACheck - no subscription ID" _ -> workerInternalError c connId "NSACheck - no active token" - NSADelete -> case ntfSubId of - Just nSubId -> - (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) - `agentFinally` continueDeletion - _ -> continueDeletion - where - continueDeletion = do - let sub' = sub {ntfSubId = Nothing, ntfSubStatus = NASOff} - withStore' c $ \db -> updateNtfSubscription db sub' (NtfSubSMPAction NSASmpDelete) ts - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) - NSARotate -> case ntfSubId of - Just nSubId -> - (lift getNtfToken >>= mapM_ (agentNtfDeleteSubscription c nSubId)) - `agentFinally` deleteCreate - _ -> deleteCreate - where - deleteCreate = do - withStore' c $ \db -> deleteNtfSubscription db connId - ns <- asks ntfSupervisor - atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate) + NSADelete -> + deleteNtfSub $ do + let sub' = sub {ntfSubId = Nothing, ntfSubStatus = NASOff} + withStore' c $ \db -> updateNtfSubscription db sub' (NSASMP NSASmpDelete) ts + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer) + NSARotate -> + deleteNtfSub $ do + withStore' c $ \db -> deleteNtfSubscription db connId + ns <- asks ntfSupervisor + atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate) where + deleteNtfSub continue = case ntfSubId of + Just nSubId -> + lift getNtfToken >>= \case + Just tkn@NtfToken {ntfServer} -> do + atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts + tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case + Left e | temporaryOrHostError e -> throwE e + _ -> continue + atomically $ incNtfServerStat c userId ntfServer ntfDeleted + Nothing -> continue + _ -> continue updateSubNextCheck ts toStatus = do checkInterval <- asks $ ntfSubCheckInterval . config let nextCheckTs = addUTCTime checkInterval ts - updateSub (NASCreated toStatus) (NtfSubNTFAction NSACheck) nextCheckTs + updateSub (NASCreated toStatus) (NSANtf NSACheck) nextCheckTs updateSub toStatus toAction actionTs' = withStore' c $ \db -> updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs' @@ -231,12 +233,10 @@ runNtfWorker c srv Worker {doWork} = do runNtfSMPWorker :: AgentClient -> SMPServer -> Worker -> AM () runNtfSMPWorker c srv Worker {doWork} = do env <- ask - delay <- asks $ ntfSMPWorkerDelay . config forever $ do waitForWork doWork ExceptT . liftIO . agentOperationBracket c AONtfNetwork throwWhenInactive $ runReaderT (runExceptT runNtfSMPOperation) env - threadDelay delay where runNtfSMPOperation = withWork c doWork (`getNextNtfSubSMPAction` srv) $ @@ -244,6 +244,7 @@ runNtfSMPWorker c srv Worker {doWork} = do logInfo $ "runNtfSMPWorker, nextSub " <> tshow nextSub ri <- asks $ reconnectInterval . config withRetryInterval ri $ \_ loop -> do + liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c processSub nextSub `catchAgentError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show) @@ -264,11 +265,12 @@ runNtfSMPWorker c srv Worker {doWork} = do let rcvNtfDhSecret = C.dh' rcvNtfSrvPubDhKey rcvNtfPrivDhKey withStore' c $ \db -> do setRcvQueueNtfCreds db connId $ Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - updateNtfSubscription db sub {ntfQueueId = Just notifierId, ntfSubStatus = NASKey} (NtfSubNTFAction NSACreate) ts + updateNtfSubscription db sub {ntfQueueId = Just notifierId, ntfSubStatus = NASKey} (NSANtf NSACreate) ts ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCNtfWorker ntfServer) _ -> workerInternalError c connId "NSASmpKey - no active token" NSASmpDelete -> do + -- TODO should we remove it after successful removal from the server? rq_ <- withStore' c $ \db -> do setRcvQueueNtfCreds db connId Nothing getPrimaryRcvQueue db connId @@ -295,7 +297,7 @@ retryOnError c name loop done e = do where retryLoop = do atomically $ endAgentOperation c AONtfNetwork - atomically $ throwWhenInactive c + liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AONtfNetwork loop diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index b123fc1ec..ea1d51a7d 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -42,6 +42,7 @@ module Simplex.Messaging.Agent.Protocol deliveryRcptsSMPAgentVersion, pqdrSMPAgentVersion, sndAuthKeySMPAgentVersion, + ratchetOnConfSMPAgentVersion, currentSMPAgentVersion, supportedSMPAgentVRange, e2eEncConnInfoLength, @@ -49,6 +50,7 @@ module Simplex.Messaging.Agent.Protocol -- * SMP agent protocol types ConnInfo, + SndQueueSecured, ACommand (..), AEvent (..), AEvt (..), @@ -153,8 +155,8 @@ import Data.Int (Int64) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Map (Map) -import qualified Data.Map as M +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, isJust) import Data.Text (Text) import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -257,11 +259,14 @@ pqdrSMPAgentVersion = VersionSMPA 5 sndAuthKeySMPAgentVersion :: VersionSMPA sndAuthKeySMPAgentVersion = VersionSMPA 6 +ratchetOnConfSMPAgentVersion :: VersionSMPA +ratchetOnConfSMPAgentVersion = VersionSMPA 7 + minSupportedSMPAgentVersion :: VersionSMPA minSupportedSMPAgentVersion = duplexHandshakeSMPAgentVersion currentSMPAgentVersion :: VersionSMPA -currentSMPAgentVersion = VersionSMPA 6 +currentSMPAgentVersion = VersionSMPA 7 supportedSMPAgentVRange :: VersionRangeSMPA supportedSMPAgentVRange = mkVersionRange minSupportedSMPAgentVersion currentSMPAgentVersion @@ -327,6 +332,8 @@ deriving instance Show AEvt type ConnInfo = ByteString +type SndQueueSecured = Bool + -- | Parameterized type for SMP agent events data AEvent (e :: AEntity) where INV :: AConnectionRequestUri -> AEvent AEConn @@ -354,6 +361,7 @@ data AEvent (e :: AEntity) where DEL_USER :: Int64 -> AEvent AENone STAT :: ConnectionStats -> AEvent AEConn OK :: AEvent AEConn + JOINED :: SndQueueSecured -> AEvent AEConn ERR :: AgentErrorType -> AEvent AEConn SUSPENDED :: AEvent AENone RFPROG :: Int64 -> Int64 -> AEvent AERcvFile @@ -422,6 +430,7 @@ data AEventTag (e :: AEntity) where DEL_USER_ :: AEventTag AENone STAT_ :: AEventTag AEConn OK_ :: AEventTag AEConn + JOINED_ :: AEventTag AEConn ERR_ :: AEventTag AEConn SUSPENDED_ :: AEventTag AENone -- XFTP commands and responses @@ -474,6 +483,7 @@ aEventTag = \case DEL_USER _ -> DEL_USER_ STAT _ -> STAT_ OK -> OK_ + JOINED _ -> JOINED_ ERR _ -> ERR_ SUSPENDED -> SUSPENDED_ RFPROG {} -> RFPROG_ diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs index 00fe4039e..35fa7c5c6 100644 --- a/src/Simplex/Messaging/Agent/RetryInterval.hs +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -9,6 +9,7 @@ module Simplex.Messaging.Agent.RetryInterval RI2State (..), withRetryInterval, withRetryIntervalCount, + withRetryForeground, withRetryLock2, updateRetryInterval2, nextRetryDelay, @@ -16,10 +17,11 @@ module Simplex.Messaging.Agent.RetryInterval where import Control.Concurrent (forkIO) +import Control.Concurrent.STM (retry) import Control.Monad (void) import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Int (Int64) -import Simplex.Messaging.Util (threadDelay', whenM) +import Simplex.Messaging.Util (threadDelay', unlessM, whenM) import UnliftIO.STM data RetryInterval = RetryInterval @@ -63,6 +65,27 @@ withRetryIntervalCount ri action = callAction 0 0 $ initialInterval ri let elapsed' = elapsed + delay callAction (n + 1) elapsed' $ nextRetryDelay elapsed' delay ri +withRetryForeground :: forall m a. MonadIO m => RetryInterval -> STM Bool -> STM Bool -> (Int64 -> m a -> m a) -> m a +withRetryForeground ri isForeground isOnline action = callAction 0 $ initialInterval ri + where + callAction :: Int64 -> Int64 -> m a + callAction elapsed delay = action delay loop + where + loop = do + -- limit delay to max Int value (~36 minutes on for 32 bit architectures) + d <- registerDelay $ fromIntegral $ min delay (fromIntegral (maxBound :: Int)) + (wasForeground, wasOnline) <- atomically $ (,) <$> isForeground <*> isOnline + reset <- atomically $ do + foreground <- isForeground + online <- isOnline + let reset = (not wasForeground && foreground) || (not wasOnline && online) + unlessM ((reset ||) <$> readTVar d) retry + pure reset + let (elapsed', delay') + | reset = (0, initialInterval ri) + | otherwise = (elapsed + delay, nextRetryDelay elapsed' delay ri) + callAction elapsed' delay' + -- This function allows action to toggle between slow and fast retry intervals. withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> (RI2State -> (RetryIntervalMode -> m ()) -> m ()) -> m () withRetryLock2 RetryInterval2 {riSlow, riFast} lock action = diff --git a/src/Simplex/Messaging/Agent/Stats.hs b/src/Simplex/Messaging/Agent/Stats.hs index 424052d74..d4663bfb1 100644 --- a/src/Simplex/Messaging/Agent/Stats.hs +++ b/src/Simplex/Messaging/Agent/Stats.hs @@ -1,17 +1,20 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Agent.Stats where +import Data.Aeson (FromJSON (..), FromJSONKey, ToJSON (..)) import qualified Data.Aeson.TH as J import Data.Int (Int64) -import Data.Map (Map) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (UserId) import Simplex.Messaging.Parsers (defaultJSON, fromTextField_) -import Simplex.Messaging.Protocol (SMPServer, XFTPServer) +import Simplex.Messaging.Protocol (SMPServer, XFTPServer, NtfServer) import Simplex.Messaging.Util (decodeJSON, encodeJSON) import UnliftIO.STM @@ -44,7 +47,12 @@ data AgentSMPServerStats = AgentSMPServerStats connSubscribed :: TVar Int, -- total successful subscription connSubAttempts :: TVar Int, -- subscription attempts connSubIgnored :: TVar Int, -- subscription results ignored (client switched to different session or it was not pending) - connSubErrs :: TVar Int -- permanent subscription errors (temporary accounted for in attempts) + connSubErrs :: TVar Int, -- permanent subscription errors (temporary accounted for in attempts) + -- notifications stats + ntfKey :: TVar Int, + ntfKeyAttempts :: TVar Int, + ntfKeyDeleted :: TVar Int, + ntfKeyDeleteAttempts :: TVar Int } data AgentSMPServerStatsData = AgentSMPServerStatsData @@ -75,10 +83,17 @@ data AgentSMPServerStatsData = AgentSMPServerStatsData _connSubscribed :: Int, _connSubAttempts :: Int, _connSubIgnored :: Int, - _connSubErrs :: Int + _connSubErrs :: Int, + _ntfKey :: OptionalInt, + _ntfKeyAttempts :: OptionalInt, + _ntfKeyDeleted :: OptionalInt, + _ntfKeyDeleteAttempts :: OptionalInt } deriving (Show) +newtype OptionalInt = OInt {toInt :: Int} + deriving (Num, Show, ToJSON) + newAgentSMPServerStats :: STM AgentSMPServerStats newAgentSMPServerStats = do sentDirect <- newTVar 0 @@ -109,6 +124,10 @@ newAgentSMPServerStats = do connSubAttempts <- newTVar 0 connSubIgnored <- newTVar 0 connSubErrs <- newTVar 0 + ntfKey <- newTVar 0 + ntfKeyAttempts <- newTVar 0 + ntfKeyDeleted <- newTVar 0 + ntfKeyDeleteAttempts <- newTVar 0 pure AgentSMPServerStats { sentDirect, @@ -138,7 +157,11 @@ newAgentSMPServerStats = do connSubscribed, connSubAttempts, connSubIgnored, - connSubErrs + connSubErrs, + ntfKey, + ntfKeyAttempts, + ntfKeyDeleted, + ntfKeyDeleteAttempts } newAgentSMPServerStatsData :: AgentSMPServerStatsData @@ -171,7 +194,11 @@ newAgentSMPServerStatsData = _connSubscribed = 0, _connSubAttempts = 0, _connSubIgnored = 0, - _connSubErrs = 0 + _connSubErrs = 0, + _ntfKey = 0, + _ntfKeyAttempts = 0, + _ntfKeyDeleted = 0, + _ntfKeyDeleteAttempts = 0 } newAgentSMPServerStats' :: AgentSMPServerStatsData -> STM AgentSMPServerStats @@ -204,6 +231,10 @@ newAgentSMPServerStats' s = do connSubAttempts <- newTVar $ _connSubAttempts s connSubIgnored <- newTVar $ _connSubIgnored s connSubErrs <- newTVar $ _connSubErrs s + ntfKey <- newTVar $ toInt $ _ntfKey s + ntfKeyAttempts <- newTVar $ toInt $ _ntfKeyAttempts s + ntfKeyDeleted <- newTVar $ toInt $ _ntfKeyDeleted s + ntfKeyDeleteAttempts <- newTVar $ toInt $ _ntfKeyDeleteAttempts s pure AgentSMPServerStats { sentDirect, @@ -233,7 +264,11 @@ newAgentSMPServerStats' s = do connSubscribed, connSubAttempts, connSubIgnored, - connSubErrs + connSubErrs, + ntfKey, + ntfKeyAttempts, + ntfKeyDeleted, + ntfKeyDeleteAttempts } -- as this is used to periodically update stats in db, @@ -268,6 +303,10 @@ getAgentSMPServerStats s = do _connSubAttempts <- readTVarIO $ connSubAttempts s _connSubIgnored <- readTVarIO $ connSubIgnored s _connSubErrs <- readTVarIO $ connSubErrs s + _ntfKey <- OInt <$> readTVarIO (ntfKey s) + _ntfKeyAttempts <- OInt <$> readTVarIO (ntfKeyAttempts s) + _ntfKeyDeleted <- OInt <$> readTVarIO (ntfKeyDeleted s) + _ntfKeyDeleteAttempts <- OInt <$> readTVarIO (ntfKeyDeleteAttempts s) pure AgentSMPServerStatsData { _sentDirect, @@ -297,7 +336,11 @@ getAgentSMPServerStats s = do _connSubscribed, _connSubAttempts, _connSubIgnored, - _connSubErrs + _connSubErrs, + _ntfKey, + _ntfKeyAttempts, + _ntfKeyDeleted, + _ntfKeyDeleteAttempts } addSMPStatsData :: AgentSMPServerStatsData -> AgentSMPServerStatsData -> AgentSMPServerStatsData @@ -330,7 +373,11 @@ addSMPStatsData sd1 sd2 = _connSubscribed = _connSubscribed sd1 + _connSubscribed sd2, _connSubAttempts = _connSubAttempts sd1 + _connSubAttempts sd2, _connSubIgnored = _connSubIgnored sd1 + _connSubIgnored sd2, - _connSubErrs = _connSubErrs sd1 + _connSubErrs sd2 + _connSubErrs = _connSubErrs sd1 + _connSubErrs sd2, + _ntfKey = _ntfKey sd1 + _ntfKey sd2, + _ntfKeyAttempts = _ntfKeyAttempts sd1 + _ntfKeyAttempts sd2, + _ntfKeyDeleted = _ntfKeyDeleted sd1 + _ntfKeyDeleted sd2, + _ntfKeyDeleteAttempts = _ntfKeyDeleteAttempts sd1 + _ntfKeyDeleteAttempts sd2 } data AgentXFTPServerStats = AgentXFTPServerStats @@ -490,18 +537,127 @@ addXFTPStatsData sd1 sd2 = _deleteErrs = _deleteErrs sd1 + _deleteErrs sd2 } +data AgentNtfServerStats = AgentNtfServerStats + { ntfCreated :: TVar Int, + ntfCreateAttempts :: TVar Int, + ntfChecked :: TVar Int, + ntfCheckAttempts :: TVar Int, + ntfDeleted :: TVar Int, + ntfDelAttempts :: TVar Int + } + +data AgentNtfServerStatsData = AgentNtfServerStatsData + { _ntfCreated :: Int, + _ntfCreateAttempts :: Int, + _ntfChecked :: Int, + _ntfCheckAttempts :: Int, + _ntfDeleted :: Int, + _ntfDelAttempts :: Int + } + deriving (Show) + +newAgentNtfServerStats :: STM AgentNtfServerStats +newAgentNtfServerStats = do + ntfCreated <- newTVar 0 + ntfCreateAttempts <- newTVar 0 + ntfChecked <- newTVar 0 + ntfCheckAttempts <- newTVar 0 + ntfDeleted <- newTVar 0 + ntfDelAttempts <- newTVar 0 + pure + AgentNtfServerStats + { ntfCreated, + ntfCreateAttempts, + ntfChecked, + ntfCheckAttempts, + ntfDeleted, + ntfDelAttempts + } + +newAgentNtfServerStatsData :: AgentNtfServerStatsData +newAgentNtfServerStatsData = + AgentNtfServerStatsData + { _ntfCreated = 0, + _ntfCreateAttempts = 0, + _ntfChecked = 0, + _ntfCheckAttempts = 0, + _ntfDeleted = 0, + _ntfDelAttempts = 0 + } + +newAgentNtfServerStats' :: AgentNtfServerStatsData -> STM AgentNtfServerStats +newAgentNtfServerStats' s = do + ntfCreated <- newTVar $ _ntfCreated s + ntfCreateAttempts <- newTVar $ _ntfCreateAttempts s + ntfChecked <- newTVar $ _ntfChecked s + ntfCheckAttempts <- newTVar $ _ntfCheckAttempts s + ntfDeleted <- newTVar $ _ntfDeleted s + ntfDelAttempts <- newTVar $ _ntfDelAttempts s + pure + AgentNtfServerStats + { ntfCreated, + ntfCreateAttempts, + ntfChecked, + ntfCheckAttempts, + ntfDeleted, + ntfDelAttempts + } + +getAgentNtfServerStats :: AgentNtfServerStats -> IO AgentNtfServerStatsData +getAgentNtfServerStats s = do + _ntfCreated <- readTVarIO $ ntfCreated s + _ntfCreateAttempts <- readTVarIO $ ntfCreateAttempts s + _ntfChecked <- readTVarIO $ ntfChecked s + _ntfCheckAttempts <- readTVarIO $ ntfCheckAttempts s + _ntfDeleted <- readTVarIO $ ntfDeleted s + _ntfDelAttempts <- readTVarIO $ ntfDelAttempts s + pure + AgentNtfServerStatsData + { _ntfCreated, + _ntfCreateAttempts, + _ntfChecked, + _ntfCheckAttempts, + _ntfDeleted, + _ntfDelAttempts + } + +addNtfStatsData :: AgentNtfServerStatsData -> AgentNtfServerStatsData -> AgentNtfServerStatsData +addNtfStatsData sd1 sd2 = + AgentNtfServerStatsData + { _ntfCreated = _ntfCreated sd1 + _ntfCreated sd2, + _ntfCreateAttempts = _ntfCreateAttempts sd1 + _ntfCreateAttempts sd2, + _ntfChecked = _ntfChecked sd1 + _ntfChecked sd2, + _ntfCheckAttempts = _ntfCheckAttempts sd1 + _ntfCheckAttempts sd2, + _ntfDeleted = _ntfDeleted sd1 + _ntfDeleted sd2, + _ntfDelAttempts = _ntfDelAttempts sd1 + _ntfDelAttempts sd2 + } + -- Type for gathering both smp and xftp stats across all users and servers, -- to then be persisted to db as a single json. data AgentPersistedServerStats = AgentPersistedServerStats { smpServersStats :: Map (UserId, SMPServer) AgentSMPServerStatsData, - xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData + xftpServersStats :: Map (UserId, XFTPServer) AgentXFTPServerStatsData, + ntfServersStats :: OptionalMap (UserId, NtfServer) AgentNtfServerStatsData } deriving (Show) +instance FromJSON OptionalInt where + parseJSON v = OInt <$> parseJSON v + omittedField = Just (OInt 0) + +newtype OptionalMap k v = OptionalMap (Map k v) + deriving (Show, ToJSON) + +instance (FromJSONKey k, Ord k, FromJSON v) => FromJSON (OptionalMap k v) where + parseJSON v = OptionalMap <$> parseJSON v + omittedField = Just (OptionalMap M.empty) + $(J.deriveJSON defaultJSON ''AgentSMPServerStatsData) $(J.deriveJSON defaultJSON ''AgentXFTPServerStatsData) +$(J.deriveJSON defaultJSON ''AgentNtfServerStatsData) + $(J.deriveJSON defaultJSON ''AgentPersistedServerStats) instance ToField AgentPersistedServerStats where diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 0727343e7..e0e4fc58f 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -220,7 +220,7 @@ module Simplex.Messaging.Agent.Store.SQLite -- * utilities withConnection, withTransaction, - withTransactionCtx, + withTransactionPriority, firstRow, firstRow', maybeFirstRow, @@ -392,10 +392,10 @@ connectSQLiteStore dbFilePath key keepKey = do dbNew <- not <$> doesFileExist dbFilePath dbConn <- dbBusyLoop (connectDB dbFilePath key) dbConnection <- newMVar dbConn - atomically $ do - dbKey <- newTVar $! storeKey key keepKey - dbClosed <- newTVar False - pure SQLiteStore {dbFilePath, dbKey, dbConnection, dbNew, dbClosed} + dbKey <- newTVarIO $! storeKey key keepKey + dbClosed <- newTVarIO False + dbSem <- newTVarIO 0 + pure SQLiteStore {dbFilePath, dbKey, dbSem, dbConnection, dbNew, dbClosed} connectDB :: FilePath -> ScrubbedBytes -> IO DB.Connection connectDB path key = do @@ -1457,23 +1457,24 @@ getNtfSubscription db connId = DB.query db [sql| - SELECT s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, + SELECT c.user_id, s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts FROM ntf_subscriptions nsb + JOIN connections c USING (conn_id) JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port JOIN ntf_servers ns USING (ntf_host, ntf_port) WHERE nsb.conn_id = ? |] (Only connId) where - ntfSubscription (smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_) = + ntfSubscription ((userId, smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash ) :. (ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_)) = let smpServer = SMPServer smpHost smpPort smpKeyHash ntfServer = NtfServer ntfHost ntfPort ntfKeyHash action = case (ntfAction_, smpAction_, actionTs_) of - (Just ntfAction, Nothing, Just actionTs) -> Just (NtfSubNTFAction ntfAction, actionTs) - (Nothing, Just smpAction, Just actionTs) -> Just (NtfSubSMPAction smpAction, actionTs) + (Just ntfAction, Nothing, Just actionTs) -> Just (NSANtf ntfAction, actionTs) + (Nothing, Just smpAction, Just actionTs) -> Just (NSASMP smpAction, actionTs) _ -> Nothing - in (NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) + in (NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ()) createNtfSubscription db ntfSubscription action = runExceptT $ do @@ -1607,18 +1608,19 @@ getNextNtfSubNTFAction db ntfServer@(NtfServer ntfHost ntfPort _) = DB.query db [sql| - SELECT s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), + SELECT c.user_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port WHERE ns.conn_id = ? |] (Only connId) where err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = + ntfSubAction (userId, smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = let smpServer = SMPServer smpHost smpPort smpKeyHash - ntfSubscription = NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} in (ntfSubscription, action, actionTs) markNtfSubActionNtfFailed_ :: DB.Connection -> ConnId -> IO () @@ -1650,18 +1652,19 @@ getNextNtfSubSMPAction db smpServer@(SMPServer smpHost smpPort _) = DB.query db [sql| - SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + SELECT c.user_id, s.ntf_host, s.ntf_port, s.ntf_key_hash, ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_smp_action FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) JOIN ntf_servers s USING (ntf_host, ntf_port) WHERE ns.conn_id = ? |] (Only connId) where err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = + ntfSubAction (userId, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = let ntfServer = NtfServer ntfHost ntfPort ntfKeyHash - ntfSubscription = NtfSubscription {connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} in (ntfSubscription, action, actionTs) markNtfSubActionSMPFailed_ :: DB.Connection -> ConnId -> IO () @@ -1906,9 +1909,11 @@ newQueueId_ (Only maxId : _) = DBQueueId (maxId + 1) getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getConn = getAnyConn False +{-# INLINE getConn #-} getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getDeletedConn = getAnyConn True +{-# INLINE getDeletedConn #-} getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getAnyConn deleted' dbConn connId = @@ -1929,9 +1934,11 @@ getAnyConn deleted' dbConn connId = getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getConns = getAnyConns_ False +{-# INLINE getConns #-} getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getDeletedConns = getAnyConns_ True +{-# INLINE getDeletedConns #-} getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db @@ -2272,8 +2279,8 @@ randomId :: TVar ChaChaDRG -> Int -> IO ByteString randomId gVar n = atomically $ U.encode <$> C.randomBytes n gVar ntfSubAndSMPAction :: NtfSubAction -> (Maybe NtfSubNTFAction, Maybe NtfSubSMPAction) -ntfSubAndSMPAction (NtfSubNTFAction action) = (Just action, Nothing) -ntfSubAndSMPAction (NtfSubSMPAction action) = (Nothing, Just action) +ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing) +ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action) createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64 createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} = diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index b9a9bd501..a7ad47f37 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -8,20 +9,20 @@ module Simplex.Messaging.Agent.Store.SQLite.Common withConnection', withTransaction, withTransaction', - withTransactionCtx, + withTransactionPriority, dbBusyLoop, storeKey, ) where import Control.Concurrent (threadDelay) +import Control.Concurrent.STM (retry) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA -import Data.Time.Clock (diffUTCTime, getCurrentTime) import Database.SQLite.Simple (SQLError) import qualified Database.SQLite.Simple as SQL import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Util (diffToMilliseconds) +import Simplex.Messaging.Util (ifM, unlessM) import qualified UnliftIO.Exception as E import UnliftIO.MVar import UnliftIO.STM @@ -32,35 +33,40 @@ storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, dbKey :: TVar (Maybe ScrubbedBytes), + dbSem :: TVar Int, dbConnection :: MVar DB.Connection, dbClosed :: TVar Bool, dbNew :: Bool } +withConnectionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a +withConnectionPriority SQLiteStore {dbSem, dbConnection} priority action + | priority = E.bracket_ signal release $ withMVar dbConnection action + | otherwise = lowPriority + where + lowPriority = wait >> withMVar dbConnection (\db -> ifM free (Just <$> action db) (pure Nothing)) >>= maybe lowPriority pure + signal = atomically $ modifyTVar' dbSem (+ 1) + release = atomically $ modifyTVar' dbSem $ \sem -> if sem > 0 then sem - 1 else 0 + wait = unlessM free $ atomically $ unlessM ((0 ==) <$> readTVar dbSem) retry + free = (0 ==) <$> readTVarIO dbSem + withConnection :: SQLiteStore -> (DB.Connection -> IO a) -> IO a -withConnection SQLiteStore {dbConnection} = withMVar dbConnection +withConnection st = withConnectionPriority st False withConnection' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a withConnection' st action = withConnection st $ action . DB.conn -withTransaction :: SQLiteStore -> (DB.Connection -> IO a) -> IO a -withTransaction = withTransactionCtx Nothing - withTransaction' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a withTransaction' st action = withTransaction st $ action . DB.conn -withTransactionCtx :: Maybe String -> SQLiteStore -> (DB.Connection -> IO a) -> IO a -withTransactionCtx ctx_ st action = withConnection st $ dbBusyLoop . transactionWithCtx +withTransaction :: SQLiteStore -> (DB.Connection -> IO a) -> IO a +withTransaction st = withTransactionPriority st False +{-# INLINE withTransaction #-} + +withTransactionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a +withTransactionPriority st priority action = withConnectionPriority st priority $ dbBusyLoop . transaction where - transactionWithCtx db@DB.Connection {conn} = case ctx_ of - Nothing -> SQL.withImmediateTransaction conn $ action db - Just ctx -> do - t1 <- getCurrentTime - r <- SQL.withImmediateTransaction conn $ action db - t2 <- getCurrentTime - putStrLn $ "withTransactionCtx start :: " <> show t1 <> " :: " <> ctx - putStrLn $ "withTransactionCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1) - pure r + transaction db@DB.Connection {conn} = SQL.withImmediateTransaction conn $ action db dbBusyLoop :: forall a. IO a -> IO a dbBusyLoop action = loop 500 3000000 diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index 2ae4eb731..b356b3f87 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -64,7 +64,7 @@ timeIt slow sql a = do open :: String -> IO Connection open f = do conn <- SQL.open f - slow <- atomically $ TM.empty + slow <- TM.emptyIO pure Connection {conn, slow} close :: Connection -> IO () diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 9ffe325b2..3b02f64ae 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -1,7 +1,9 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} module Simplex.Messaging.Agent.TRcvQueues ( TRcvQueues (getRcvQueues, getConnections), + Queue (..), empty, clear, deleteConn, @@ -9,9 +11,9 @@ module Simplex.Messaging.Agent.TRcvQueues addQueue, batchAddQueues, deleteQueue, + hasSessQueues, getSessQueues, getDelSessQueues, - qKey, ) where @@ -25,46 +27,51 @@ import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) import Simplex.Messaging.Protocol (RecipientId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Transport + +class Queue q where + connId' :: q -> ConnId + qKey :: q -> (UserId, SMPServer, ConnId) -- the fields in this record have the same data with swapped keys for lookup efficiency, -- and all methods must maintain this invariant. -data TRcvQueues = TRcvQueues - { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) RcvQueue, +data TRcvQueues q = TRcvQueues + { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) q, getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId)) } -empty :: STM TRcvQueues -empty = TRcvQueues <$> TM.empty <*> TM.empty +empty :: IO (TRcvQueues q) +empty = TRcvQueues <$> TM.emptyIO <*> TM.emptyIO -clear :: TRcvQueues -> STM () +clear :: TRcvQueues q -> STM () clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs -deleteConn :: ConnId -> TRcvQueues -> STM () +deleteConn :: ConnId -> TRcvQueues q -> STM () deleteConn cId (TRcvQueues qs cs) = TM.lookupDelete cId cs >>= \case Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks Nothing -> pure () -hasConn :: ConnId -> TRcvQueues -> STM Bool +hasConn :: ConnId -> TRcvQueues q -> STM Bool hasConn cId (TRcvQueues _ cs) = TM.member cId cs -addQueue :: RcvQueue -> TRcvQueues -> STM () +addQueue :: Queue q => q -> TRcvQueues q -> STM () addQueue rq (TRcvQueues qs cs) = do TM.insert k rq qs - TM.alter addQ (connId rq) cs + TM.alter addQ (connId' rq) cs where addQ = Just . maybe (k :| []) (k <|) k = qKey rq -- Save time by aggregating modifyTVar -batchAddQueues :: Foldable t => TRcvQueues -> t RcvQueue -> STM () +batchAddQueues :: (Foldable t, Queue q) => TRcvQueues q -> t q -> STM () batchAddQueues (TRcvQueues qs cs) rqs = do modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs - modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId rq) cs') now rqs + modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId' rq) cs') now rqs where addQ k = Just . maybe (k :| []) (k <|) -deleteQueue :: RcvQueue -> TRcvQueues -> STM () +deleteQueue :: RcvQueue -> TRcvQueues RcvQueue -> STM () deleteQueue rq (TRcvQueues qs cs) = do TM.delete k qs TM.update delQ (connId rq) cs @@ -72,21 +79,25 @@ deleteQueue rq (TRcvQueues qs cs) = do delQ = L.nonEmpty . L.filter (/= k) k = qKey rq -getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] -getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVar qs +hasSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> STM Bool +hasSessQueues tSess (TRcvQueues qs _) = any (`isSession` tSess) <$> readTVar qs + +getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> IO [RcvQueue] +getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVarIO qs where addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' -getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], [ConnId]) -getDelSessQueues tSess (TRcvQueues qs cs) = do +getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> SessionId -> TRcvQueues (SessionId, RcvQueue) -> STM ([RcvQueue], [ConnId]) +getDelSessQueues tSess sessId' (TRcvQueues qs cs) = do (removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs writeTVar qs $! qs'' removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs pure (removedQs, removedConns) where - delQ acc@(removed, qs') rq - | rq `isSession` tSess = (rq : removed, M.delete (qKey rq) qs') + delQ acc@(removed, qs') (sessId, rq) + | rq `isSession` tSess && sessId == sessId' = (rq : removed, M.delete (qKey rq) qs') | otherwise = acc + delConn :: ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) -> RcvQueue -> ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, ConnId))) delConn (removed, cs') rq = M.alterF f cId cs' where cId = connId rq @@ -100,5 +111,10 @@ isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool isSession rq (uId, srv, connId_) = userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ -qKey :: RcvQueue -> (UserId, SMPServer, ConnId) -qKey rq = (userId rq, server rq, connId rq) +instance Queue RcvQueue where + connId' = connId + qKey rq = (userId rq, server rq, connId rq) + +instance Queue (SessionId, RcvQueue) where + connId' = connId . snd + qKey = qKey . snd diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 80fd65ffc..b4567c62e 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -170,17 +170,17 @@ data PClient v err msg = PClient msgQ :: Maybe (TBQueue (ServerTransmissionBatch v err msg)) } -smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe (THandleAuth 'TClient) -> STM SMPClient +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe (THandleAuth 'TClient) -> IO SMPClient smpClientStub g sessionId thVersion thAuth = do let ts = UTCTime (read "2024-03-31") 0 - connected <- newTVar False - clientCorrId <- C.newRandomDRG g - sentCommands <- TM.empty - sendPings <- newTVar False - lastReceived <- newTVar ts - timeoutErrorCount <- newTVar 0 - sndQ <- newTBQueue 100 - rcvQ <- newTBQueue 100 + connected <- newTVarIO False + clientCorrId <- atomically $ C.newRandomDRG g + sentCommands <- TM.emptyIO + sendPings <- newTVarIO False + lastReceived <- newTVarIO ts + timeoutErrorCount <- newTVarIO 0 + sndQ <- newTBQueueIO 100 + rcvQ <- newTBQueueIO 100 return ProtocolClient { action = Nothing, @@ -452,21 +452,21 @@ getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> T getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, clientALPN, serverVRange, agreeSecret} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> - (getCurrentTime >>= atomically . mkProtocolClient useHost >>= runClient useTransport useHost) + (getCurrentTime >>= mkProtocolClient useHost >>= runClient useTransport useHost) `catch` \(e :: IOException) -> pure . Left $ PCEIOError e Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> UTCTime -> STM (PClient v err msg) + mkProtocolClient :: TransportHost -> UTCTime -> IO (PClient v err msg) mkProtocolClient transportHost ts = do - connected <- newTVar False - sendPings <- newTVar False - lastReceived <- newTVar ts - timeoutErrorCount <- newTVar 0 - clientCorrId <- C.newRandomDRG g - sentCommands <- TM.empty - sndQ <- newTBQueue qSize - rcvQ <- newTBQueue qSize + connected <- newTVarIO False + sendPings <- newTVarIO False + lastReceived <- newTVarIO ts + timeoutErrorCount <- newTVarIO 0 + clientCorrId <- atomically $ C.newRandomDRG g + sentCommands <- TM.emptyIO + sndQ <- newTBQueueIO qSize + rcvQ <- newTBQueueIO qSize return PClient { connected, @@ -565,7 +565,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize processMsg ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) | B.null $ bs corrId = sendMsg $ STEvent clientResp | otherwise = - atomically (TM.lookup corrId sentCommands) >>= \case + TM.lookupIO corrId sentCommands >>= \case Nothing -> sendMsg $ STUnexpectedError unexpected Just Request {entityId, command, pending, responseVar} -> do wasPending <- @@ -1089,13 +1089,13 @@ mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentC nonce@(C.CbNonce corrId) <- maybe (atomically $ C.randomCbNonce clientCorrId) pure nonce_ let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (CorrId corrId, entityId, command) auth = authTransmission (thAuth thParams) pKey_ nonce tForAuth - r <- atomically $ mkRequest (CorrId corrId) + r <- mkRequest (CorrId corrId) pure ((,tToSend) <$> auth, r) where - mkRequest :: CorrId -> STM (Request err msg) + mkRequest :: CorrId -> IO (Request err msg) mkRequest corrId = do - pending <- newTVar True - responseVar <- newEmptyTMVar + pending <- newTVarIO True + responseVar <- newEmptyTMVarIO let r = Request { corrId, @@ -1104,7 +1104,7 @@ mkTransmission_ ProtocolClient {thParams, client_ = PClient {clientCorrId, sentC pending, responseVar } - TM.insert corrId r sentCommands + atomically $ TM.insert corrId r sentCommands pure r authTransmission :: Maybe (THandleAuth 'TClient) -> Maybe C.APrivateAuthKey -> C.CbNonce -> ByteString -> Either TransportError (Maybe TransmissionAuth) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 99c77f67c..8073f1d48 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -100,7 +100,7 @@ data SMPClientAgent = SMPClientAgent randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, smpSessions :: TMap SessionId (OwnServer, SMPClient), - srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), + srvSubs :: TMap SMPServer (TMap SMPSub (SessionId, C.APrivateAuthKey)), pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), smpSubWorkers :: TMap SMPServer (SessionVar (Async ())), workerSeq :: TVar Int @@ -108,17 +108,17 @@ data SMPClientAgent = SMPClientAgent type OwnServer = Bool -newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent +newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO SMPClientAgent newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do - active <- newTVar True - msgQ <- newTBQueue msgQSize - agentQ <- newTBQueue agentQSize - smpClients <- TM.empty - smpSessions <- TM.empty - srvSubs <- TM.empty - pendingSrvSubs <- TM.empty - smpSubWorkers <- TM.empty - workerSeq <- newTVar 0 + active <- newTVarIO True + msgQ <- newTBQueueIO msgQSize + agentQ <- newTBQueueIO agentQSize + smpClients <- TM.emptyIO + smpSessions <- TM.emptyIO + srvSubs <- TM.emptyIO + pendingSrvSubs <- TM.emptyIO + smpSubWorkers <- TM.emptyIO + workerSeq <- newTVarIO 0 pure SMPClientAgent { agentCfg, @@ -204,14 +204,17 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random removeClientAndSubs :: SMPClient -> IO (Maybe (Map SMPSub C.APrivateAuthKey)) removeClientAndSubs smp = atomically $ do + TM.delete sessId smpSessions removeSessVar v srv smpClients - TM.delete (sessionId $ thParams smp) smpSessions - TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs + TM.lookup srv (srvSubs ca) >>= mapM updateSubs where + sessId = sessionId $ thParams smp updateSubs sVar = do - ss <- readTVar sVar - addSubs_ (pendingSrvSubs ca) srv ss - pure ss + -- removing subscriptions that have matching sessionId to disconnected client + -- and keep the other ones (they can be made by the new client) + pending <- M.map snd <$> stateTVar sVar (M.partition ((sessId ==) . fst)) + addSubs_ (pendingSrvSubs ca) srv pending + pure pending serverDown :: Map SMPSub C.APrivateAuthKey -> IO () serverDown ss = unless (M.null ss) $ do @@ -226,7 +229,7 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s where getWorkerVar ts = ifM - (null <$> getPending) + (noPending) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq srv smpSubWorkers ts) newSubWorker :: SessionVar (Async ()) -> IO () @@ -235,12 +238,13 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s atomically $ putTMVar (sessionVar v) a runSubWorker = withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> do - pending <- atomically getPending + pending <- liftIO getPending unless (null pending) $ whenM (readTVarIO active) $ do void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv pending) loop ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg - getPending = maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca) + noPending = maybe (pure True) (fmap M.null . readTVar) =<< TM.lookup srv (pendingSrvSubs ca) + getPending = maybe (pure M.empty) readTVarIO =<< TM.lookupIO srv (pendingSrvSubs ca) cleanup :: SessionVar (Async ()) -> STM () cleanup v = do -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. @@ -251,14 +255,14 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s reconnectSMPClient :: SMPClientAgent -> SMPServer -> Map SMPSub C.APrivateAuthKey -> ExceptT SMPClientError IO () reconnectSMPClient ca@SMPClientAgent {agentCfg} srv cs = withSMP ca srv $ \smp -> liftIO $ do - currSubs <- atomically $ maybe (pure M.empty) readTVar =<< TM.lookup srv (srvSubs ca) + currSubs <- maybe (pure M.empty) readTVarIO =<< TM.lookupIO srv (srvSubs ca) let (nSubs, rSubs) = foldr (groupSub currSubs) ([], []) $ M.assocs cs subscribe_ smp SPNotifier nSubs subscribe_ smp SPRecipient rSubs where - groupSub :: Map SMPSub C.APrivateAuthKey -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) - groupSub currSubs (s@(party, qId), k) (nSubs, rSubs) - | M.member s currSubs = (nSubs, rSubs) + groupSub :: Map SMPSub (SessionId, C.APrivateAuthKey) -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) + groupSub currSubs (s@(party, qId), k) acc@(nSubs, rSubs) + | M.member s currSubs = acc | otherwise = case party of SPNotifier -> (s' : nSubs, rSubs) SPRecipient -> (nSubs, s' : rSubs) @@ -286,8 +290,8 @@ getConnectedSMPServerClient SMPClientAgent {smpClients} srv = (Nothing <$ atomically (removeSessVar v srv smpClients)) -- proxy will create a new connection (pure $ Just $ Left e) -- not expired, returning error -lookupSMPServerClient :: SMPClientAgent -> SessionId -> STM (Maybe (OwnServer, SMPClient)) -lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookup sessId smpSessions +lookupSMPServerClient :: SMPClientAgent -> SessionId -> IO (Maybe (OwnServer, SMPClient)) +lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookupIO sessId smpSessions closeSMPClientAgent :: SMPClientAgent -> IO () closeSMPClientAgent c = do @@ -346,17 +350,18 @@ smpSubscribeQueues party ca smp srv subs = do when tempErrs $ reconnectClient ca srv Nothing -> reconnectClient ca srv where - processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) + processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) processSubscriptions rs = do pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca) let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs) unless (null oks) $ addSubscriptions ca srv party oks unless (null notPending) $ removePendingSubs ca srv party notPending pure acc - groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) - groupSub pending (s@(qId, _), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of + sessId = sessionId $ thParams smp + groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, (SessionId, C.APrivateAuthKey))], [QueueId]) + groupSub pending ((qId, pk), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of Right () - | M.member (party, qId) pending -> (tempErrs, finalErrs, s : oks, qId : notPending) + | M.member (party, qId) pending -> (tempErrs, finalErrs, (qId, (sessId, pk)) : oks, qId : notPending) | otherwise -> acc Left e | temporaryClientError e -> (True, finalErrs, oks, notPending) @@ -379,7 +384,7 @@ showServer :: SMPServer -> ByteString showServer ProtocolServer {host, port} = strEncode host <> B.pack (if null port then "" else ':' : port) -addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM () +addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, (SessionId, C.APrivateAuthKey))] -> STM () addSubscriptions = addSubsList_ . srvSubs {-# INLINE addSubscriptions #-} @@ -387,12 +392,12 @@ addPendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APr addPendingSubs = addSubsList_ . pendingSrvSubs {-# INLINE addPendingSubs #-} -addSubsList_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM () +addSubsList_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSubParty -> [(QueueId, s)] -> STM () addSubsList_ subs srv party ss = addSubs_ subs srv ss' where ss' = M.fromList $ map (first (party,)) ss -addSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> Map SMPSub C.APrivateAuthKey -> STM () +addSubs_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> Map SMPSub s -> STM () addSubs_ subs srv ss = TM.lookup srv subs >>= \case Just m -> TM.union ss m @@ -402,7 +407,7 @@ removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () removeSubscription = removeSub_ . srvSubs {-# INLINE removeSubscription #-} -removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM () +removeSub_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSub -> STM () removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM () diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 2bf8dbcbf..1192148ac 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -176,10 +176,10 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge getSMPSubscriber :: SMPServer -> M SMPSubscriber getSMPSubscriber smpServer = - atomically (TM.lookup smpServer smpSubscribers) >>= maybe createSMPSubscriber pure + liftIO (TM.lookupIO smpServer smpSubscribers) >>= maybe createSMPSubscriber pure where createSMPSubscriber = do - sub@SMPSubscriber {subThreadId} <- atomically newSMPSubscriber + sub@SMPSubscriber {subThreadId} <- liftIO newSMPSubscriber atomically $ TM.insert smpServer sub smpSubscribers tId <- mkWeakThreadId =<< forkIO (runSMPSubscriber sub) atomically . writeTVar subThreadId $ Just tId @@ -333,7 +333,7 @@ runNtfClientTransport :: Transport c => THandleNTF c 'TServer -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config ts <- liftIO getSystemTime - c <- atomically $ newNtfServerClient qSize params ts + c <- liftIO $ newNtfServerClient qSize params ts s <- asks subscriber ps <- asks pushServer expCfg <- asks $ inactiveClientExpiration . config @@ -507,7 +507,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu | otherwise -> do logDebug "TCRN" atomically $ writeTVar tknCronInterval int - atomically (TM.lookup tknId intervalNotifiers) >>= \case + liftIO (TM.lookupIO tknId intervalNotifiers) >>= \case Nothing -> runIntervalNotifier int Just IntervalNotifier {interval, action} -> unless (interval == int) $ do @@ -585,7 +585,7 @@ incNtfStat statSel = do saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getNtfServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getNtfServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 5ebd5230e..dc0cb0a73 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -10,7 +10,6 @@ module Simplex.Messaging.Notifications.Server.Env where import Control.Concurrent (ThreadId) import Control.Concurrent.Async (Async) import Control.Logger.Simple -import Control.Monad.IO.Unlift import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -85,16 +84,16 @@ data NtfEnv = NtfEnv newNtfServerEnv :: NtfServerConfig -> IO NtfEnv newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do - random <- liftIO C.newRandom - store <- atomically newNtfStore + random <- C.newRandom + store <- newNtfStore logInfo "restoring subscriptions..." - storeLog <- liftIO $ mapM (`readWriteNtfStore` store) storeLogFile + storeLog <- mapM (`readWriteNtfStore` store) storeLogFile logInfo "restored subscriptions" - subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg random - pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig - tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) - Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile - serverStats <- atomically . newNtfServerStats =<< liftIO getCurrentTime + subscriber <- newNtfSubscriber subQSize smpAgentCfg random + pushServer <- newNtfPushServer pushQSize apnsConfig + tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) + Fingerprint fp <- loadFingerprint caCertificateFile + serverStats <- newNtfServerStats =<< getCurrentTime pure NtfEnv {config, subscriber, pushServer, store, storeLog, random, tlsServerParams, serverIdentity = C.KeyHash fp, serverStats} data NtfSubscriber = NtfSubscriber @@ -103,10 +102,10 @@ data NtfSubscriber = NtfSubscriber smpAgent :: SMPClientAgent } -newNtfSubscriber :: Natural -> SMPClientAgentConfig -> TVar ChaChaDRG -> STM NtfSubscriber +newNtfSubscriber :: Natural -> SMPClientAgentConfig -> TVar ChaChaDRG -> IO NtfSubscriber newNtfSubscriber qSize smpAgentCfg random = do - smpSubscribers <- TM.empty - newSubQ <- newTBQueue qSize + smpSubscribers <- TM.emptyIO + newSubQ <- newTBQueueIO qSize smpAgent <- newSMPClientAgent smpAgentCfg random pure NtfSubscriber {smpSubscribers, newSubQ, smpAgent} @@ -115,10 +114,10 @@ data SMPSubscriber = SMPSubscriber subThreadId :: TVar (Maybe (Weak ThreadId)) } -newSMPSubscriber :: STM SMPSubscriber +newSMPSubscriber :: IO SMPSubscriber newSMPSubscriber = do - newSubQ <- newTQueue - subThreadId <- newTVar Nothing + newSubQ <- newTQueueIO + subThreadId <- newTVarIO Nothing pure SMPSubscriber {newSubQ, subThreadId} data NtfPushServer = NtfPushServer @@ -134,11 +133,11 @@ data IntervalNotifier = IntervalNotifier interval :: Word16 } -newNtfPushServer :: Natural -> APNSPushClientConfig -> STM NtfPushServer +newNtfPushServer :: Natural -> APNSPushClientConfig -> IO NtfPushServer newNtfPushServer qSize apnsConfig = do - pushQ <- newTBQueue qSize - pushClients <- TM.empty - intervalNotifiers <- TM.empty + pushQ <- newTBQueueIO qSize + pushClients <- TM.emptyIO + intervalNotifiers <- TM.emptyIO pure NtfPushServer {pushQ, pushClients, intervalNotifiers, apnsConfig} newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient @@ -151,7 +150,7 @@ newPushClient NtfPushServer {apnsConfig, pushClients} pp = do getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient getPushClient s@NtfPushServer {pushClients} pp = - atomically (TM.lookup pp pushClients) >>= maybe (newPushClient s pp) pure + TM.lookupIO pp pushClients >>= maybe (newPushClient s pp) pure data NtfRequest = NtfReqNew CorrId ANewNtfEntity @@ -167,11 +166,11 @@ data NtfServerClient = NtfServerClient sndActiveAt :: TVar SystemTime } -newNtfServerClient :: Natural -> THandleParams NTFVersion 'TServer -> SystemTime -> STM NtfServerClient +newNtfServerClient :: Natural -> THandleParams NTFVersion 'TServer -> SystemTime -> IO NtfServerClient newNtfServerClient qSize ntfThParams ts = do - rcvQ <- newTBQueue qSize - sndQ <- newTBQueue qSize - connected <- newTVar True - rcvActiveAt <- newTVar ts - sndActiveAt <- newTVar ts + rcvQ <- newTBQueueIO qSize + sndQ <- newTBQueueIO qSize + connected <- newTVarIO True + rcvActiveAt <- newTVarIO ts + sndActiveAt <- newTVarIO ts return NtfServerClient {rcvQ, sndQ, ntfThParams, connected, rcvActiveAt, sndActiveAt} diff --git a/src/Simplex/Messaging/Notifications/Server/Stats.hs b/src/Simplex/Messaging/Notifications/Server/Stats.hs index 7debc1ac9..b73e6098f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Stats.hs +++ b/src/Simplex/Messaging/Notifications/Server/Stats.hs @@ -40,30 +40,30 @@ data NtfServerStatsData = NtfServerStatsData _activeSubs :: PeriodStatsData NotifierId } -newNtfServerStats :: UTCTime -> STM NtfServerStats +newNtfServerStats :: UTCTime -> IO NtfServerStats newNtfServerStats ts = do - fromTime <- newTVar ts - tknCreated <- newTVar 0 - tknVerified <- newTVar 0 - tknDeleted <- newTVar 0 - subCreated <- newTVar 0 - subDeleted <- newTVar 0 - ntfReceived <- newTVar 0 - ntfDelivered <- newTVar 0 + fromTime <- newTVarIO ts + tknCreated <- newTVarIO 0 + tknVerified <- newTVarIO 0 + tknDeleted <- newTVarIO 0 + subCreated <- newTVarIO 0 + subDeleted <- newTVarIO 0 + ntfReceived <- newTVarIO 0 + ntfDelivered <- newTVarIO 0 activeTokens <- newPeriodStats activeSubs <- newPeriodStats pure NtfServerStats {fromTime, tknCreated, tknVerified, tknDeleted, subCreated, subDeleted, ntfReceived, ntfDelivered, activeTokens, activeSubs} -getNtfServerStatsData :: NtfServerStats -> STM NtfServerStatsData +getNtfServerStatsData :: NtfServerStats -> IO NtfServerStatsData getNtfServerStatsData s@NtfServerStats {fromTime} = do - _fromTime <- readTVar fromTime - _tknCreated <- readTVar $ tknCreated s - _tknVerified <- readTVar $ tknVerified s - _tknDeleted <- readTVar $ tknDeleted s - _subCreated <- readTVar $ subCreated s - _subDeleted <- readTVar $ subDeleted s - _ntfReceived <- readTVar $ ntfReceived s - _ntfDelivered <- readTVar $ ntfDelivered s + _fromTime <- readTVarIO fromTime + _tknCreated <- readTVarIO $ tknCreated s + _tknVerified <- readTVarIO $ tknVerified s + _tknDeleted <- readTVarIO $ tknDeleted s + _subCreated <- readTVarIO $ subCreated s + _subDeleted <- readTVarIO $ subDeleted s + _ntfReceived <- readTVarIO $ ntfReceived s + _ntfDelivered <- readTVarIO $ ntfDelivered s _activeTokens <- getPeriodStatsData $ activeTokens s _activeSubs <- getPeriodStatsData $ activeSubs s pure NtfServerStatsData {_fromTime, _tknCreated, _tknVerified, _tknDeleted, _subCreated, _subDeleted, _ntfReceived, _ntfDelivered, _activeTokens, _activeSubs} diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 83dc1a4c2..b4d91dc88 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -33,13 +33,13 @@ data NtfStore = NtfStore subscriptionLookup :: TMap SMPQueueNtf NtfSubscriptionId } -newNtfStore :: STM NtfStore +newNtfStore :: IO NtfStore newNtfStore = do - tokens <- TM.empty - tokenRegistrations <- TM.empty - subscriptions <- TM.empty - tokenSubscriptions <- TM.empty - subscriptionLookup <- TM.empty + tokens <- TM.emptyIO + tokenRegistrations <- TM.emptyIO + subscriptions <- TM.emptyIO + tokenSubscriptions <- TM.emptyIO + subscriptionLookup <- TM.emptyIO pure NtfStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup} data NtfTknData = NtfTknData @@ -77,6 +77,9 @@ data NtfEntityRec (e :: NtfEntity) where getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) getNtfToken st tknId = TM.lookup tknId (tokens st) +getNtfTokenIO :: NtfStore -> NtfTokenId -> IO (Maybe NtfTknData) +getNtfTokenIO st tknId = TM.lookupIO tknId (tokens st) + addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do TM.insert tknId tkn $ tokens st diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs index 4465f8767..8fcedab53 100644 --- a/src/Simplex/Messaging/Notifications/Types.hs +++ b/src/Simplex/Messaging/Notifications/Types.hs @@ -11,7 +11,7 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time (UTCTime) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) -import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..)) +import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol @@ -48,6 +48,7 @@ data NtfToken = NtfToken ntfServer :: NtfServer, ntfTokenId :: Maybe NtfTokenId, -- TODO combine keys to key pair as the types should match + -- | key used by the ntf server to verify transmissions ntfPubKey :: C.APublicAuthKey, -- | key used by the ntf client to sign transmissions @@ -79,17 +80,17 @@ newNtfToken deviceToken ntfServer (ntfPubKey, ntfPrivKey) ntfDhKeys ntfMode = ntfMode } -data NtfSubAction = NtfSubNTFAction NtfSubNTFAction | NtfSubSMPAction NtfSubSMPAction +data NtfSubAction = NSANtf NtfSubNTFAction | NSASMP NtfSubSMPAction deriving (Show) isDeleteNtfSubAction :: NtfSubAction -> Bool isDeleteNtfSubAction = \case - NtfSubNTFAction a -> case a of + NSANtf a -> case a of NSACreate -> False NSACheck -> False NSADelete -> True NSARotate -> True - NtfSubSMPAction a -> case a of + NSASMP a -> case a of NSASmpKey -> False NSASmpDelete -> True @@ -177,7 +178,8 @@ instance FromField NtfAgentSubStatus where fromField = fromTextField_ $ either ( instance ToField NtfAgentSubStatus where toField = toField . decodeLatin1 . smpEncode data NtfSubscription = NtfSubscription - { connId :: ConnId, + { userId :: UserId, + connId :: ConnId, smpServer :: SMPServer, ntfQueueId :: Maybe NotifierId, ntfServer :: NtfServer, @@ -186,10 +188,11 @@ data NtfSubscription = NtfSubscription } deriving (Show) -newNtfSubscription :: ConnId -> SMPServer -> Maybe NotifierId -> NtfServer -> NtfAgentSubStatus -> NtfSubscription -newNtfSubscription connId smpServer ntfQueueId ntfServer ntfSubStatus = +newNtfSubscription :: UserId -> ConnId -> SMPServer -> Maybe NotifierId -> NtfServer -> NtfAgentSubStatus -> NtfSubscription +newNtfSubscription userId connId smpServer ntfQueueId ntfServer ntfSubStatus = NtfSubscription - { connId, + { userId, + connId, smpServer, ntfQueueId, ntfServer, diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index d88b2349a..c5d067475 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -37,6 +37,7 @@ module Simplex.Messaging.Server ) where +import Control.Concurrent.STM.TQueue (flushTQueue) import Control.Logger.Simple import Control.Monad import Control.Monad.Except @@ -47,6 +48,7 @@ import Crypto.Random import Control.Monad.STM (retry) import Data.Bifunctor (first) import Data.ByteString.Base64 (encode) +import qualified Data.ByteString.Builder as BLD import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB @@ -54,6 +56,7 @@ import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) import qualified Data.IntMap.Strict as IM +import qualified Data.IntSet as IS import Data.List (intercalate, mapAccumR) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L @@ -69,6 +72,7 @@ import Data.Type.Equality import GHC.Stats (getRTSStats) import GHC.TypeLits (KnownNat) import Network.Socket (ServiceName, Socket, socketToHandle) +import Numeric.Natural (Natural) import Simplex.Messaging.Agent.Lock import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPTransmission, smpProxyError, temporaryClientError) import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient) @@ -158,28 +162,33 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do forall s. Server -> String -> - (Server -> TQueue (QueueId, Client)) -> + (Server -> TQueue (QueueId, Client, Subscribed)) -> (Server -> TMap QueueId Client) -> (Client -> TMap QueueId s) -> (s -> IO ()) -> M () serverThread s label subQ subs clientSubs unsub = do labelMyThread label + cls <- asks clients forever $ - atomically updateSubscribers + atomically (updateSubscribers cls) $>>= endPreviousSubscriptions >>= liftIO . mapM_ unsub where - updateSubscribers :: STM (Maybe (QueueId, Client)) - updateSubscribers = do - (qId, clnt) <- readTQueue $ subQ s - let clientToBeNotified c' = - if sameClientId clnt c' - then pure Nothing - else do + updateSubscribers :: TVar (IM.IntMap Client) -> STM (Maybe (QueueId, Client)) + updateSubscribers cls = do + (qId, clnt, subscribed) <- readTQueue $ subQ s + current <- IM.member (clientId clnt) <$> readTVar cls + let updateSub + | not subscribed = TM.lookupDelete + | not current = TM.lookup -- do not insert client if it is already disconnected, but send END to any other client + | otherwise = (`TM.lookupInsert` clnt) -- insert subscribed and current client + clientToBeNotified c' + | sameClientId clnt c' = pure Nothing + | otherwise = do yes <- readTVar $ connected c' pure $ if yes then Just (qId, c') else Nothing - TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified + updateSub qId (subs s) $>>= clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s) endPreviousSubscriptions (qId, c) = do forkClient c (label <> ".endPreviousSubscriptions") $ @@ -229,7 +238,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath liftIO $ threadDelay' $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0) - ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedNew, qDeletedSecured, qSub, qSubAuth, qSubDuplicate, qSubProhibited, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgExpired, activeQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} <- asks serverStats + ss@ServerStats {fromTime, qCreated, qSecured, qDeletedAll, qDeletedNew, qDeletedSecured, qSub, qSubNoMsg, qSubAuth, qSubDuplicate, qSubProhibited, ntfCreated, ntfDeleted, ntfSub, ntfSubAuth, ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, msgRecvGet, msgGet, msgGetNoMsg, msgGetAuth, msgGetDuplicate, msgGetProhibited, msgExpired, activeQueues, subscribedQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, qCount, msgCount, pRelays, pRelaysOwn, pMsgFwds, pMsgFwdsOwn, pMsgFwdsRecv} + <- asks serverStats + QueueStore {queues, notifiers} <- asks queueStore let interval = 1000000 * logInterval forever $ do withFile statsFilePath AppendMode $ \h -> liftIO $ do @@ -242,16 +253,29 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do qDeletedNew' <- atomically $ swapTVar qDeletedNew 0 qDeletedSecured' <- atomically $ swapTVar qDeletedSecured 0 qSub' <- atomically $ swapTVar qSub 0 + qSubNoMsg' <- atomically $ swapTVar qSubNoMsg 0 qSubAuth' <- atomically $ swapTVar qSubAuth 0 qSubDuplicate' <- atomically $ swapTVar qSubDuplicate 0 qSubProhibited' <- atomically $ swapTVar qSubProhibited 0 + ntfCreated' <- atomically $ swapTVar ntfCreated 0 + ntfDeleted' <- atomically $ swapTVar ntfDeleted 0 + ntfSub' <- atomically $ swapTVar ntfSub 0 + ntfSubAuth' <- atomically $ swapTVar ntfSubAuth 0 + ntfSubDuplicate' <- atomically $ swapTVar ntfSubDuplicate 0 msgSent' <- atomically $ swapTVar msgSent 0 msgSentAuth' <- atomically $ swapTVar msgSentAuth 0 msgSentQuota' <- atomically $ swapTVar msgSentQuota 0 msgSentLarge' <- atomically $ swapTVar msgSentLarge 0 msgRecv' <- atomically $ swapTVar msgRecv 0 + msgRecvGet' <- atomically $ swapTVar msgRecvGet 0 + msgGet' <- atomically $ swapTVar msgGet 0 + msgGetNoMsg' <- atomically $ swapTVar msgGetNoMsg 0 + msgGetAuth' <- atomically $ swapTVar msgGetAuth 0 + msgGetDuplicate' <- atomically $ swapTVar msgGetDuplicate 0 + msgGetProhibited' <- atomically $ swapTVar msgGetProhibited 0 msgExpired' <- atomically $ swapTVar msgExpired 0 ps <- atomically $ periodStatCounts activeQueues ts + psSub <- atomically $ periodStatCounts subscribedQueues ts msgSentNtf' <- atomically $ swapTVar msgSentNtf 0 msgRecvNtf' <- atomically $ swapTVar msgRecvNtf 0 psNtf <- atomically $ periodStatCounts activeQueuesNtf ts @@ -264,6 +288,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do pMsgFwdsOwn' <- atomically $ getResetProxyStatsData pMsgFwdsOwn pMsgFwdsRecv' <- atomically $ swapTVar pMsgFwdsRecv 0 qCount' <- readTVarIO qCount + qCount'' <- M.size <$> readTVarIO queues + ntfCount' <- M.size <$> readTVarIO notifiers msgCount' <- readTVarIO msgCount hPutStrLn h $ intercalate @@ -302,7 +328,24 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do show msgSentLarge', show msgNtfs', show msgNtfNoSub', - show msgNtfLost' + show msgNtfLost', + show qSubNoMsg', + show msgRecvGet', + show msgGet', + show msgGetNoMsg', + show msgGetAuth', + show msgGetDuplicate', + show msgGetProhibited', + dayCount psSub, + weekCount psSub, + monthCount psSub, + show qCount'', + show ntfCreated', + show ntfDeleted', + show ntfSub', + show ntfSubAuth', + show ntfSubDuplicate', + show ntfCount' ] ) liftIO $ threadDelay' interval @@ -379,21 +422,33 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do hPutStrLn h . B.unpack $ B.intercalate "," [bshow cid, encode sessionId, connected', strEncode createdAt, rcvActiveAt', sndActiveAt', bshow age, subscriptions'] CPStats -> withUserRole $ do ss <- unliftIO u $ asks serverStats - let putStat :: Show a => ByteString -> (ServerStats -> TVar a) -> IO () - putStat label var = readTVarIO (var ss) >>= \v -> B.hPutStr h $ label <> ": " <> bshow v <> "\n" - putProxyStat :: ByteString -> (ServerStats -> ProxyStats) -> IO () + let getStat :: (ServerStats -> TVar a) -> IO a + getStat var = readTVarIO (var ss) + putStat :: Show a => String -> (ServerStats -> TVar a) -> IO () + putStat label var = getStat var >>= \v -> hPutStrLn h $ label <> ": " <> show v + putProxyStat :: String -> (ServerStats -> ProxyStats) -> IO () putProxyStat label var = do - ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- atomically $ getProxyStatsData $ var ss - B.hPutStr h $ label <> ": requests=" <> bshow _pRequests <> ", successes=" <> bshow _pSuccesses <> ", errorsConnect=" <> bshow _pErrorsConnect <> ", errorsCompat=" <> bshow _pErrorsCompat <> ", errorsOther=" <> bshow _pErrorsOther <> "\n" + ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- getProxyStatsData $ var ss + hPutStrLn h $ label <> ": requests=" <> show _pRequests <> ", successes=" <> show _pSuccesses <> ", errorsConnect=" <> show _pErrorsConnect <> ", errorsCompat=" <> show _pErrorsCompat <> ", errorsOther=" <> show _pErrorsOther putStat "fromTime" fromTime putStat "qCreated" qCreated putStat "qSecured" qSecured putStat "qDeletedAll" qDeletedAll putStat "qDeletedNew" qDeletedNew putStat "qDeletedSecured" qDeletedSecured - readTVarIO (day $ activeQueues ss) >>= \v -> B.hPutStr h $ "dayMsgQueues" <> ": " <> bshow (S.size v) <> "\n" + getStat (day . activeQueues) >>= \v -> hPutStrLn h $ "daily active queues: " <> show (S.size v) + getStat (day . subscribedQueues) >>= \v -> hPutStrLn h $ "daily subscribed queues: " <> show (S.size v) + putStat "qSub" qSub + putStat "qSubNoMsg" qSubNoMsg + subs <- (,,) <$> getStat qSubAuth <*> getStat qSubDuplicate <*> getStat qSubProhibited + hPutStrLn h $ "other SUB events (auth, duplicate, prohibited): " <> show subs putStat "msgSent" msgSent putStat "msgRecv" msgRecv + putStat "msgRecvGet" msgRecvGet + putStat "msgGet" msgGet + putStat "msgGetNoMsg" msgGet + gets <- (,,) <$> getStat msgGetAuth <*> getStat msgGetDuplicate <*> getStat msgGetProhibited + hPutStrLn h $ "other GET events (auth, duplicate, prohibited): " <> show gets putStat "msgSentNtf" msgSentNtf putStat "msgRecvNtf" msgRecvNtf putStat "qCount" qCount @@ -417,7 +472,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do #endif CPSockets -> withUserRole $ do (accepted', closed', active') <- unliftIO u $ asks sockets - (accepted, closed, active) <- atomically $ (,,) <$> readTVar accepted' <*> readTVar closed' <*> readTVar active' + (accepted, closed, active) <- (,,) <$> readTVarIO accepted' <*> readTVarIO closed' <*> readTVarIO active' hPutStrLn h "Sockets: " hPutStrLn h $ "accepted: " <> show accepted hPutStrLn h $ "closed: " <> show closed @@ -452,28 +507,77 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do activeClients <- readTVarIO clients hPutStrLn h $ "Clients: " <> show (IM.size activeClients) when (r == CPRAdmin) $ do - (smpSubCnt, smpClCnt) <- countClientSubs subscriptions activeClients - (ntfSubCnt, ntfClCnt) <- countClientSubs ntfSubscriptions activeClients - hPutStrLn h $ "SMP subscriptions (via clients, slow): " <> show smpSubCnt - hPutStrLn h $ "SMP subscribed clients (via clients, slow): " <> show smpClCnt - hPutStrLn h $ "Ntf subscriptions (via clients, slow): " <> show ntfSubCnt - hPutStrLn h $ "Ntf subscribed clients (via clients, slow): " <> show ntfClCnt - activeSubs <- readTVarIO subscribers - activeNtfSubs <- readTVarIO notifiers - hPutStrLn h $ "SMP subscriptions: " <> show (M.size activeSubs) - hPutStrLn h $ "SMP subscribed clients: " <> show (countSubClients activeSubs) - hPutStrLn h $ "Ntf subscriptions: " <> show (M.size activeNtfSubs) - hPutStrLn h $ "Ntf subscribed clients: " <> show (countSubClients activeNtfSubs) + clQs <- clientTBQueueLengths activeClients + hPutStrLn h $ "Client queues (rcvQ, sndQ, msgQ): " <> show clQs + (smpSubCnt, smpSubCntByGroup, smpClCnt, smpClQs) <- countClientSubs subscriptions (Just countSMPSubs) activeClients + hPutStrLn h $ "SMP subscriptions (via clients): " <> show smpSubCnt + hPutStrLn h $ "SMP subscriptions (by group: NoSub, SubPending, SubThread, ProhibitSub): " <> show smpSubCntByGroup + hPutStrLn h $ "SMP subscribed clients (via clients): " <> show smpClCnt + hPutStrLn h $ "SMP subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show smpClQs + (ntfSubCnt, _, ntfClCnt, ntfClQs) <- countClientSubs ntfSubscriptions Nothing activeClients + hPutStrLn h $ "Ntf subscriptions (via clients): " <> show ntfSubCnt + hPutStrLn h $ "Ntf subscribed clients (via clients): " <> show ntfClCnt + hPutStrLn h $ "Ntf subscribed clients queues (via clients, rcvQ, sndQ, msgQ): " <> show ntfClQs + putActiveClientsInfo "SMP" subscribers + putActiveClientsInfo "Ntf" notifiers where - countClientSubs :: (Client -> TMap QueueId a) -> IM.IntMap Client -> IO (Int, Int) - countClientSubs subSel = foldM addSubs (0, 0) + putActiveClientsInfo :: String -> TMap QueueId Client -> IO () + putActiveClientsInfo protoName clients = do + activeSubs <- readTVarIO clients + hPutStrLn h $ protoName <> " subscriptions: " <> show (M.size activeSubs) + clCnt <- if r == CPRAdmin then putClientQueues activeSubs else pure $ countSubClients activeSubs + hPutStrLn h $ protoName <> " subscribed clients: " <> show clCnt where - addSubs :: (Int, Int) -> Client -> IO (Int, Int) - addSubs (subCnt, clCnt) cl = do + putClientQueues :: M.Map QueueId Client -> IO Int + putClientQueues subs = do + let cls = differentClients subs + clQs <- clientTBQueueLengths cls + hPutStrLn h $ protoName <> " subscribed clients queues (rcvQ, sndQ, msgQ): " <> show clQs + pure $ length cls + differentClients :: M.Map QueueId Client -> [Client] + differentClients = fst . M.foldl' addClient ([], IS.empty) + where + addClient acc@(cls, clSet) cl@Client {clientId} + | IS.member clientId clSet = acc + | otherwise = (cl : cls, IS.insert clientId clSet) + countSubClients :: M.Map QueueId Client -> Int + countSubClients = IS.size . M.foldr' (IS.insert . clientId) IS.empty + countClientSubs :: (Client -> TMap QueueId a) -> Maybe (M.Map QueueId a -> IO (Int, Int, Int, Int)) -> IM.IntMap Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + countClientSubs subSel countSubs_ = foldM addSubs (0, (0, 0, 0, 0), 0, (0, 0, 0)) + where + addSubs :: (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) -> Client -> IO (Int, (Int, Int, Int, Int), Int, (Natural, Natural, Natural)) + addSubs (!subCnt, cnts@(!c1, !c2, !c3, !c4), !clCnt, !qs) cl = do subs <- readTVarIO $ subSel cl + cnts' <- case countSubs_ of + Nothing -> pure cnts + Just countSubs -> do + (c1', c2', c3', c4') <- countSubs subs + pure (c1 + c1', c2 + c2', c3 + c3', c4 + c4') let cnt = M.size subs - pure (subCnt + cnt, clCnt + if cnt == 0 then 0 else 1) - countSubClients = S.size . M.foldr' (S.insert . clientId) S.empty + clCnt' = if cnt == 0 then clCnt else clCnt + 1 + qs' <- if cnt == 0 then pure qs else addQueueLengths qs cl + pure (subCnt + cnt, cnts', clCnt', qs') + clientTBQueueLengths :: Foldable t => t Client -> IO (Natural, Natural, Natural) + clientTBQueueLengths = foldM addQueueLengths (0, 0, 0) + addQueueLengths (!rl, !sl, !ml) cl = do + (rl', sl', ml') <- queueLengths cl + pure (rl + rl', sl + sl', ml + ml') + queueLengths Client {rcvQ, sndQ, msgQ} = do + rl <- atomically $ lengthTBQueue rcvQ + sl <- atomically $ lengthTBQueue sndQ + ml <- atomically $ lengthTBQueue msgQ + pure (rl, sl, ml) + countSMPSubs :: M.Map QueueId Sub -> IO (Int, Int, Int, Int) + countSMPSubs = foldM countSubs (0, 0, 0, 0) + where + countSubs (c1, c2, c3, c4) Sub {subThread} = case subThread of + ServerSub t -> do + st <- readTVarIO t + pure $ case st of + NoSub -> (c1 + 1, c2, c3, c4) + SubPending -> (c1, c2 + 1, c3, c4) + SubThread _ -> (c1, c2, c3 + 1, c4) + ProhibitSub -> pure (c1, c2, c3, c4 + 1) CPDelete queueId' -> withUserRole $ unliftIO u $ do st <- asks queueStore ms <- asks msgStore @@ -515,10 +619,8 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio ts <- liftIO getSystemTime active <- asks clients nextClientId <- asks clientSeq - c <- atomically $ do - new@Client {clientId} <- newClient nextClientId q thVersion sessionId ts - modifyTVar' active $ IM.insert clientId new - pure new + c@Client {clientId} <- liftIO $ newClient nextClientId q thVersion sessionId ts + atomically $ modifyTVar' active $ IM.insert clientId c s <- asks server expCfg <- asks $ inactiveClientExpiration . config th <- newMVar h -- put TH under a fair lock to interleave messages and command responses @@ -528,22 +630,26 @@ runClientTransport h@THandle {params = thParams@THandleParams {thVersion, sessio where disconnectThread_ c (Just expCfg) = [liftIO $ disconnectTransport h (rcvActiveAt c) (sndActiveAt c) expCfg (noSubscriptions c)] disconnectThread_ _ _ = [] - noSubscriptions c = atomically $ (&&) <$> TM.null (subscriptions c) <*> TM.null (ntfSubscriptions c) + noSubscriptions c = atomically $ (&&) <$> TM.null (ntfSubscriptions c) <*> (not . hasSubs <$> readTVar (subscriptions c)) + hasSubs = any $ (\case ServerSub _ -> True; ProhibitSub -> False) . subThread clientDisconnected :: Client -> M () -clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endThreads} = do +clientDisconnected c@Client {clientId, subscriptions, ntfSubscriptions, connected, sessionId, endThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc" - subs <- atomically $ do + (subs, ntfSubs) <- atomically $ do writeTVar connected False - swapTVar subscriptions M.empty + (,) <$> swapTVar subscriptions M.empty <*> swapTVar ntfSubscriptions M.empty liftIO $ mapM_ cancelSub subs - srvSubs <- asks $ subscribers . server - atomically $ modifyTVar' srvSubs $ \cs -> - M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs + Server {subscribers, notifiers} <- asks server + updateSubscribers subs subscribers + updateSubscribers ntfSubs notifiers asks clients >>= atomically . (`modifyTVar'` IM.delete clientId) tIds <- atomically $ swapTVar endThreads IM.empty liftIO $ mapM_ (mapM_ killThread <=< deRefWeak) tIds where + updateSubscribers subs srvSubs = do + atomically $ modifyTVar' srvSubs $ \cs -> + M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs deleteCurrentClient :: Client -> Maybe Client deleteCurrentClient c' | sameClientId c c' = Nothing @@ -553,10 +659,12 @@ sameClientId :: Client -> Client -> Bool sameClientId Client {clientId} Client {clientId = cId'} = clientId == cId' cancelSub :: Sub -> IO () -cancelSub s = - readTVarIO (subThread s) >>= \case - SubThread t -> liftIO $ deRefWeak t >>= mapM_ killThread - _ -> pure () +cancelSub s = case subThread s of + ServerSub st -> + readTVarIO st >>= \case + SubThread t -> liftIO $ deRefWeak t >>= mapM_ killThread + _ -> pure () + ProhibitSub -> pure () receive :: Transport c => THandleSMP c 'TServer -> Client -> M () receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do @@ -579,8 +687,10 @@ receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiv VRVerified qr -> pure $ Right (qr, (corrId, entId, cmd)) VRFailed -> do case cmd of - Cmd _ SEND {} -> atomically $ modifyTVar' (msgSentAuth stats) (+ 1) - Cmd _ SUB -> atomically $ modifyTVar' (qSubAuth stats) (+ 1) + Cmd _ SEND {} -> incStat $ msgSentAuth stats + Cmd _ SUB -> incStat $ qSubAuth stats + Cmd _ NSUB -> incStat $ ntfSubAuth stats + Cmd _ GET -> incStat $ msgGetAuth stats _ -> pure () pure $ Left (corrId, entId, ERR AUTH) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty @@ -775,7 +885,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi ProxyAgent {smpAgent = a} <- asks proxyAgent ServerStats {pMsgFwds, pMsgFwdsOwn} <- asks serverStats let inc = mkIncProxyStats pMsgFwds pMsgFwdsOwn - atomically (lookupSMPServerClient a sessId) >>= \case + liftIO (lookupSMPServerClient a sessId) >>= \case Just (own, smp) -> do inc own pRequests if v >= sendingProxySMPVersion @@ -808,13 +918,13 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi transportErr = PROXY . BROKER . TRANSPORT mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> TVar Int) -> m () mkIncProxyStats ps psOwn own sel = do - atomically $ modifyTVar' (sel ps) (+ 1) - when own $ atomically $ modifyTVar' (sel psOwn) (+ 1) + incStat $ sel ps + when own $ incStat $ sel psOwn processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Maybe (Transmission BrokerMsg)) - processCommand (qr_, (corrId, queueId, cmd)) = case cmd of - Cmd SProxiedClient command -> processProxiedCmd (corrId, queueId, command) + processCommand (qr_, (corrId, entId, cmd)) = case cmd of + Cmd SProxiedClient command -> processProxiedCmd (corrId, entId, command) Cmd SSender command -> Just <$> case command of - SKEY sKey -> (corrId,queueId,) <$> case qr_ of + SKEY sKey -> (corrId,entId,) <$> case qr_ of Just QueueRec {sndSecure, recipientId} | sndSecure -> secureQueue_ "SKEY" recipientId sKey | otherwise -> pure $ ERR AUTH @@ -830,15 +940,15 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi ifM allowNew (createQueue st rKey dhKey subMode sndSecure) - (pure (corrId, queueId, ERR AUTH)) + (pure (corrId, entId, ERR AUTH)) where allowNew = do ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth - SUB -> withQueue (`subscribeQueue` queueId) + SUB -> withQueue (`subscribeQueue` entId) GET -> withQueue getMessage ACK msgId -> withQueue (`acknowledgeMsg` msgId) - KEY sKey -> (corrId,queueId,) <$> case qr_ of + KEY sKey -> (corrId,entId,) <$> case qr_ of Just QueueRec {recipientId} -> secureQueue_ "KEY" recipientId sKey Nothing -> pure $ ERR INTERNAL NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey @@ -863,7 +973,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi status = QueueActive, sndSecure } - (corrId,queueId,) <$> addQueueRetry 3 qik qRec + (corrId,entId,) <$> addQueueRetry 3 qik qRec where addQueueRetry :: Int -> ((RecipientId, SenderId) -> QueueIdsKeys) -> ((RecipientId, SenderId) -> QueueRec) -> M BrokerMsg @@ -878,8 +988,8 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi Right _ -> do withLog (`logCreateById` rId) stats <- asks serverStats - atomically $ modifyTVar' (qCreated stats) (+ 1) - atomically $ modifyTVar' (qCount stats) (+ 1) + incStat $ qCreated stats + incStat $ qCount stats case subMode of SMOnlyCreate -> pure () SMSubscribe -> void $ subscribeQueue qr rId @@ -901,152 +1011,178 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi withLog $ \s -> logSecureQueue s rId sKey st <- asks queueStore stats <- asks serverStats - atomically $ modifyTVar' (qSecured stats) (+ 1) + incStat $ qSecured stats atomically $ either ERR (const OK) <$> secureQueue st rId sKey addQueueNotifier_ :: QueueStore -> NtfPublicAuthKey -> RcvNtfPublicDhKey -> M (Transmission BrokerMsg) addQueueNotifier_ st notifierKey dhKey = time "NKEY" $ do (rcvPublicDhKey, privDhKey) <- atomically . C.generateKeyPair =<< asks random let rcvNtfDhSecret = C.dh' dhKey privDhKey - (corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret + (corrId,entId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret where addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> M BrokerMsg addNotifierRetry 0 _ _ = pure $ ERR INTERNAL addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do notifierId <- randomId =<< asks (queueIdBytes . config) let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} - atomically (addQueueNotifier st queueId ntfCreds) >>= \case + atomically (addQueueNotifier st entId ntfCreds) >>= \case Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret Left e -> pure $ ERR e Right _ -> do - withLog $ \s -> logAddNotifier s queueId ntfCreds + withLog $ \s -> logAddNotifier s entId ntfCreds + incStat . ntfCreated =<< asks serverStats pure $ NID notifierId rcvPublicDhKey deleteQueueNotifier_ :: QueueStore -> M (Transmission BrokerMsg) deleteQueueNotifier_ st = do - withLog (`logDeleteNotifier` queueId) - okResp <$> atomically (deleteQueueNotifier st queueId) + withLog (`logDeleteNotifier` entId) + atomically (deleteQueueNotifier st entId) >>= \case + Right () -> do + -- Possibly, the same should be done if the queue is suspended, but currently we do not use it + atomically $ writeTQueue ntfSubscribedQ (entId, clnt, False) + incStat . ntfDeleted =<< asks serverStats + pure ok + Left e -> pure $ err e suspendQueue_ :: QueueStore -> M (Transmission BrokerMsg) suspendQueue_ st = do - withLog (`logSuspendQueue` queueId) - okResp <$> atomically (suspendQueue st queueId) + withLog (`logSuspendQueue` entId) + okResp <$> atomically (suspendQueue st entId) subscribeQueue :: QueueRec -> RecipientId -> M (Transmission BrokerMsg) subscribeQueue qr rId = do - stats <- asks serverStats atomically (TM.lookup rId subscriptions) >>= \case - Nothing -> do - atomically $ modifyTVar' (qSub stats) (+ 1) - newSub >>= deliver - Just s@Sub {subThread} -> - readTVarIO subThread >>= \case + Nothing -> newSub >>= deliver True + Just s@Sub {subThread} -> do + stats <- asks serverStats + case subThread of ProhibitSub -> do -- cannot use SUB in the same connection where GET was used - atomically $ modifyTVar' (qSubProhibited stats) (+ 1) + incStat $ qSubProhibited stats pure (corrId, rId, ERR $ CMD PROHIBITED) _ -> do - atomically $ modifyTVar' (qSubDuplicate stats) (+ 1) - atomically (tryTakeTMVar $ delivered s) >> deliver s + incStat $ qSubDuplicate stats + atomically (tryTakeTMVar $ delivered s) >> deliver False s where newSub :: M Sub newSub = time "SUB newSub" . atomically $ do - writeTQueue subscribedQ (rId, clnt) + writeTQueue subscribedQ (rId, clnt, True) sub <- newSubscription NoSub TM.insert rId sub subscriptions pure sub - deliver :: Sub -> M (Transmission BrokerMsg) - deliver sub = do + deliver :: Bool -> Sub -> M (Transmission BrokerMsg) + deliver inc sub = do q <- getStoreMsgQueue "SUB" rId msg_ <- atomically $ tryPeekMsg q + when inc $ do + stats <- asks serverStats + incStat $ (if isJust msg_ then qSub else qSubNoMsg) stats + atomically $ updatePeriodStats (subscribedQueues stats) rId deliverMessage "SUB" qr rId sub msg_ getMessage :: QueueRec -> M (Transmission BrokerMsg) getMessage qr = time "GET" $ do - atomically (TM.lookup queueId subscriptions) >>= \case + atomically (TM.lookup entId subscriptions) >>= \case Nothing -> - atomically newSub >>= getMessage_ + atomically newSub >>= (`getMessage_` Nothing) Just s@Sub {subThread} -> - readTVarIO subThread >>= \case + case subThread of ProhibitSub -> atomically (tryTakeTMVar $ delivered s) - >> getMessage_ s + >>= getMessage_ s -- cannot use GET in the same connection where there is an active subscription - _ -> pure (corrId, queueId, ERR $ CMD PROHIBITED) + _ -> do + stats <- asks serverStats + incStat $ msgGetProhibited stats + pure (corrId, entId, ERR $ CMD PROHIBITED) where newSub :: STM Sub newSub = do - s <- newSubscription ProhibitSub - TM.insert queueId s subscriptions + s <- newProhibitedSub + TM.insert entId s subscriptions pure s - getMessage_ :: Sub -> M (Transmission BrokerMsg) - getMessage_ s = do - q <- getStoreMsgQueue "GET" queueId - atomically $ - tryPeekMsg q >>= \case - Just msg -> - let encMsg = encryptMsg qr msg - in setDelivered s msg $> (corrId, queueId, MSG encMsg) - _ -> pure (corrId, queueId, OK) + getMessage_ :: Sub -> Maybe MsgId -> M (Transmission BrokerMsg) + getMessage_ s delivered_ = do + q <- getStoreMsgQueue "GET" entId + stats <- asks serverStats + (statCnt, r) <- + atomically $ + tryPeekMsg q >>= \case + Just msg -> + let encMsg = encryptMsg qr msg + cnt = if isJust delivered_ then msgGetDuplicate else msgGet + in setDelivered s msg $> (cnt, (corrId, entId, MSG encMsg)) + _ -> pure (msgGetNoMsg, (corrId, entId, OK)) + incStat $ statCnt stats + pure r withQueue :: (QueueRec -> M (Transmission BrokerMsg)) -> M (Transmission BrokerMsg) withQueue action = maybe (pure $ err AUTH) action qr_ subscribeNotifications :: M (Transmission BrokerMsg) - subscribeNotifications = time "NSUB" . atomically $ do - unlessM (TM.member queueId ntfSubscriptions) $ do - writeTQueue ntfSubscribedQ (queueId, clnt) - TM.insert queueId () ntfSubscriptions + subscribeNotifications = do + statCount <- + time "NSUB" . atomically $ do + ifM + (TM.member entId ntfSubscriptions) + (pure ntfSubDuplicate) + (newSub $> ntfSub) + incStat . statCount =<< asks serverStats pure ok + where + newSub = do + writeTQueue ntfSubscribedQ (entId, clnt, True) + TM.insert entId () ntfSubscriptions acknowledgeMsg :: QueueRec -> MsgId -> M (Transmission BrokerMsg) acknowledgeMsg qr msgId = time "ACK" $ do - atomically (TM.lookup queueId subscriptions) >>= \case + liftIO (TM.lookupIO entId subscriptions) >>= \case Nothing -> pure $ err NO_MSG Just sub -> atomically (getDelivered sub) >>= \case Just st -> do - q <- getStoreMsgQueue "ACK" queueId + q <- getStoreMsgQueue "ACK" entId case st of ProhibitSub -> do deletedMsg_ <- atomically $ tryDelMsg q msgId - mapM_ updateStats deletedMsg_ + mapM_ (updateStats True) deletedMsg_ pure ok _ -> do (deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId - mapM_ updateStats deletedMsg_ - deliverMessage "ACK" qr queueId sub msg_ + mapM_ (updateStats False) deletedMsg_ + deliverMessage "ACK" qr entId sub msg_ _ -> pure $ err NO_MSG where - getDelivered :: Sub -> STM (Maybe SubscriptionThread) + getDelivered :: Sub -> STM (Maybe ServerSub) getDelivered Sub {delivered, subThread} = do tryTakeTMVar delivered $>>= \msgId' -> if msgId == msgId' || B.null msgId - then Just <$> readTVar subThread + then pure $ Just subThread else putTMVar delivered msgId' $> Nothing - updateStats :: Message -> M () - updateStats = \case + updateStats :: Bool -> Message -> M () + updateStats isGet = \case MessageQuota {} -> pure () Message {msgFlags} -> do stats <- asks serverStats - atomically $ modifyTVar' (msgRecv stats) (+ 1) + incStat $ msgRecv stats + when isGet $ incStat $ msgRecvGet stats atomically $ modifyTVar' (msgCount stats) (subtract 1) - atomically $ updatePeriodStats (activeQueues stats) queueId + atomically $ updatePeriodStats (activeQueues stats) entId when (notification msgFlags) $ do - atomically $ modifyTVar' (msgRecvNtf stats) (+ 1) - atomically $ updatePeriodStats (activeQueuesNtf stats) queueId + incStat $ msgRecvNtf stats + atomically $ updatePeriodStats (activeQueuesNtf stats) entId sendMessage :: QueueRec -> MsgFlags -> MsgBody -> M (Transmission BrokerMsg) sendMessage qr msgFlags msgBody | B.length msgBody > maxMessageLength thVersion = do stats <- asks serverStats - atomically $ modifyTVar' (msgSentLarge stats) (+ 1) + incStat $ msgSentLarge stats pure $ err LARGE_MSG | otherwise = do stats <- asks serverStats case status qr of QueueOff -> do - atomically $ modifyTVar' (msgSentAuth stats) (+ 1) + incStat $ msgSentAuth stats pure $ err AUTH QueueActive -> case C.maxLenBS msgBody of @@ -1058,7 +1194,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi atomically . writeMsg q =<< mkMessage body case msg_ of Nothing -> do - atomically $ modifyTVar' (msgSentQuota stats) (+ 1) + incStat $ msgSentQuota stats pure $ err QUOTA Just (msg, wasEmpty) -> time "SEND ok" $ do when wasEmpty $ tryDeliverMessage msg @@ -1066,16 +1202,16 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi forM_ (notifier qr) $ \ntf -> do asks random >>= atomically . trySendNotification ntf msg >>= \case Nothing -> do - atomically $ modifyTVar' (msgNtfNoSub stats) (+ 1) + incStat $ msgNtfNoSub stats logWarn "No notification subscription" Just False -> do - atomically $ modifyTVar' (msgNtfLost stats) (+ 1) + incStat $ msgNtfLost stats logWarn "Dropped message notification" - Just True -> atomically $ modifyTVar' (msgNtfs stats) (+ 1) - atomically $ modifyTVar' (msgSentNtf stats) (+ 1) + Just True -> incStat $ msgNtfs stats + incStat $ msgSentNtf stats atomically $ updatePeriodStats (activeQueuesNtf stats) (recipientId qr) - atomically $ modifyTVar' (msgSent stats) (+ 1) - atomically $ modifyTVar' (msgCount stats) (+ 1) + incStat $ msgSent stats + incStat $ msgCount stats atomically $ updatePeriodStats (activeQueues stats) (recipientId qr) pure ok where @@ -1110,26 +1246,28 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi deliverToSub = TM.lookup rId subscribers $>>= \rc@Client {subscriptions = subs, sndQ = q} -> TM.lookup rId subs - $>>= \s@Sub {subThread, delivered} -> readTVar subThread >>= \case - NoSub -> - tryTakeTMVar delivered >>= \case - Just _ -> pure Nothing -- if a message was already delivered, should not deliver more - Nothing -> - ifM - (isFullTBQueue q) - (writeTVar subThread SubPending $> Just (rc, s)) - (deliver q s $> Nothing) - _ -> pure Nothing + $>>= \s@Sub {subThread, delivered} -> case subThread of + ProhibitSub -> pure Nothing + ServerSub st -> readTVar st >>= \case + NoSub -> + tryTakeTMVar delivered >>= \case + Just _ -> pure Nothing -- if a message was already delivered, should not deliver more + Nothing -> + ifM + (isFullTBQueue q) + (writeTVar st SubPending $> Just (rc, s, st)) + (deliver q s $> Nothing) + _ -> pure Nothing deliver q s = do let encMsg = encryptMsg qr msg writeTBQueue q [(CorrId "", rId, MSG encMsg)] void $ setDelivered s msg - forkDeliver (rc@Client {sndQ = q}, s@Sub {subThread, delivered}) = do + forkDeliver (rc@Client {sndQ = q}, s@Sub {delivered}, st) = do t <- mkWeakThreadId =<< forkIO deliverThread - atomically . modifyTVar' subThread $ \case + atomically . modifyTVar' st $ \case -- this case is needed because deliverThread can exit before it SubPending -> SubThread t - st -> st + st' -> st' where deliverThread = do labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " deliver/SEND" @@ -1139,7 +1277,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi Just _ -> pure () -- if a message was already delivered, should not deliver more Nothing -> do deliver q s - writeTVar subThread NoSub + writeTVar st NoSub trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool) trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg = @@ -1197,7 +1335,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi let fr = FwdResponse {fwdCorrId, fwdResponse = r2} r3 = EncFwdResponse $ C.cbEncryptNoPad sessSecret (C.reverseNonce proxyNonce) (smpEncode fr) stats <- asks serverStats - atomically $ modifyTVar' (pMsgFwdsRecv stats) (+ 1) + incStat $ pMsgFwdsRecv stats pure $ RRES r3 where rejectOrVerify :: Maybe (THandleAuth 'TServer) -> SignedTransmission ErrorType Cmd -> M (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd)) @@ -1217,7 +1355,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi VRFailed -> Left (corrId', entId', ERR AUTH) deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> M (Transmission BrokerMsg) deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $ - readTVar subThread >>= \case + case subThread of ProhibitSub -> pure resp _ -> case msg_ of Just msg -> @@ -1228,7 +1366,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi resp = (corrId, rId, OK) time :: T.Text -> M a -> M a - time name = timed name queueId + time name = timed name entId encryptMsg :: QueueRec -> Message -> RcvMessage encryptMsg qr msg = encrypt . encodeRcvMsgBody $ case msg of @@ -1251,37 +1389,44 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi delQueueAndMsgs :: QueueStore -> M (Transmission BrokerMsg) delQueueAndMsgs st = do - withLog (`logDeleteQueue` queueId) + withLog (`logDeleteQueue` entId) ms <- asks msgStore - atomically (deleteQueue st queueId $>>= \q -> delMsgQueue ms queueId $> Right q) >>= \case - Right q -> updateDeletedStats q $> ok + atomically (deleteQueue st entId $>>= \q -> delMsgQueue ms entId $> Right q) >>= \case + Right q -> do + -- Possibly, the same should be done if the queue is suspended, but currently we do not use it + atomically $ writeTQueue subscribedQ (entId, clnt, False) + atomically $ writeTQueue ntfSubscribedQ (entId, clnt, False) + updateDeletedStats q + pure ok Left e -> pure $ err e getQueueInfo :: QueueRec -> M (Transmission BrokerMsg) getQueueInfo QueueRec {senderKey, notifier} = do - q@MsgQueue {size} <- getStoreMsgQueue "getQueueInfo" queueId + q@MsgQueue {size} <- getStoreMsgQueue "getQueueInfo" entId info <- atomically $ do - qiSub <- TM.lookup queueId subscriptions >>= mapM mkQSub + qiSub <- TM.lookup entId subscriptions >>= mapM mkQSub qiSize <- readTVar size qiMsg <- toMsgInfo <$$> tryPeekMsg q pure QueueInfo {qiSnd = isJust senderKey, qiNtf = isJust notifier, qiSub, qiSize, qiMsg} - pure (corrId, queueId, INFO info) + pure (corrId, entId, INFO info) where mkQSub Sub {subThread, delivered} = do - st <- readTVar subThread - let qSubThread = case st of + qSubThread <- case subThread of + ServerSub t -> do + st <- readTVar t + pure $ case st of NoSub -> QNoSub SubPending -> QSubPending SubThread _ -> QSubThread - ProhibitSub -> QProhibitSub + ProhibitSub -> pure QProhibitSub qDelivered <- decodeLatin1 . encode <$$> tryReadTMVar delivered pure QSub {qSubThread, qDelivered} ok :: Transmission BrokerMsg - ok = (corrId, queueId, OK) + ok = (corrId, entId, OK) err :: ErrorType -> Transmission BrokerMsg - err e = (corrId, queueId, ERR e) + err e = (corrId, entId, ERR e) okResp :: Either ErrorType () -> Transmission BrokerMsg okResp = either err $ const ok @@ -1290,9 +1435,13 @@ updateDeletedStats :: QueueRec -> M () updateDeletedStats q = do stats <- asks serverStats let delSel = if isNothing (senderKey q) then qDeletedNew else qDeletedSecured - atomically $ modifyTVar' (delSel stats) (+ 1) - atomically $ modifyTVar' (qDeletedAll stats) (+ 1) - atomically $ modifyTVar' (qCount stats) (subtract 1) + incStat $ delSel stats + incStat $ qDeletedAll stats + incStat $ qCount stats + +incStat :: MonadIO m => TVar Int -> m () +incStat v = atomically $ modifyTVar' v (+ 1) +{-# INLINE incStat #-} withLog :: (StoreLog 'WriteMode -> IO a) -> M () withLog action = do @@ -1321,13 +1470,16 @@ saveServerMessages keepMsgs = asks (storeMsgsFile . config) >>= mapM_ saveMessag logInfo $ "saving messages to file " <> T.pack f ms <- asks msgStore liftIO . withFile f WriteMode $ \h -> - readTVarIO ms >>= mapM_ (saveQueueMsgs ms h) . M.keys + readTVarIO ms >>= mapM_ (saveQueueMsgs h) . M.assocs logInfo "messages saved" where - getMessages = if keepMsgs then snapshotMsgQueue else flushMsgQueue - saveQueueMsgs ms h rId = - atomically (getMessages ms rId) - >>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId) + saveQueueMsgs h (rId, q) = BLD.hPutBuilder h . encodeMessages rId =<< atomically (getMessages $ msgQueue q) + getMessages = if keepMsgs then snapshotTQueue else flushTQueue + snapshotTQueue q = do + msgs <- flushTQueue q + mapM_ (writeTQueue q) msgs + pure msgs + encodeMessages rId = mconcat . map (\msg -> BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n') restoreServerMessages :: M Int restoreServerMessages = @@ -1370,7 +1522,7 @@ restoreServerMessages = saveServerStats :: M () saveServerStats = asks (serverStatsBackupFile . config) - >>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f) + >>= mapM_ (\f -> asks serverStats >>= liftIO . getServerStatsData >>= liftIO . saveStats f) where saveStats f stats = do logInfo $ "saving server stats to file " <> T.pack f diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index b40e9fc16..84e664607 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -1,13 +1,15 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StrictData #-} module Simplex.Messaging.Server.Env.STM where import Control.Concurrent (ThreadId) -import Control.Monad.IO.Unlift +import Control.Logger.Simple +import Control.Monad import Crypto.Random import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) @@ -17,6 +19,7 @@ import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (isJust, isNothing) +import qualified Data.Text as T import Data.Time.Clock (getCurrentTime) import Data.Time.Clock.System (SystemTime) import Data.X509.Validation (Fingerprint (..)) @@ -104,7 +107,7 @@ defaultMessageExpiration = defaultInactiveClientExpiration :: ExpirationConfig defaultInactiveClientExpiration = ExpirationConfig - { ttl = 43200, -- seconds, 12 hours + { ttl = 21600, -- seconds, 6 hours checkInterval = 3600 -- seconds, 1 hours } @@ -128,10 +131,12 @@ data Env = Env proxyAgent :: ProxyAgent -- senders served on this proxy } +type Subscribed = Bool + data Server = Server - { subscribedQ :: TQueue (RecipientId, Client), + { subscribedQ :: TQueue (RecipientId, Client, Subscribed), subscribers :: TMap RecipientId Client, - ntfSubscribedQ :: TQueue (NotifierId, Client), + ntfSubscribedQ :: TQueue (NotifierId, Client, Subscribed), notifiers :: TMap NotifierId Client, savingLock :: Lock } @@ -160,68 +165,77 @@ data Client = Client sndActiveAt :: TVar SystemTime } -data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) | ProhibitSub +data ServerSub = ServerSub (TVar SubscriptionThread) | ProhibitSub + +data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) data Sub = Sub - { subThread :: TVar SubscriptionThread, + { subThread :: ServerSub, -- Nothing value indicates that sub delivered :: TMVar MsgId } -newServer :: STM Server +newServer :: IO Server newServer = do - subscribedQ <- newTQueue - subscribers <- TM.empty - ntfSubscribedQ <- newTQueue - notifiers <- TM.empty - savingLock <- createLock + subscribedQ <- newTQueueIO + subscribers <- TM.emptyIO + ntfSubscribedQ <- newTQueueIO + notifiers <- TM.emptyIO + savingLock <- atomically createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client +newClient :: TVar ClientId -> Natural -> VersionSMP -> ByteString -> SystemTime -> IO Client newClient nextClientId qSize thVersion sessionId createdAt = do - clientId <- stateTVar nextClientId $ \next -> (next, next + 1) - subscriptions <- TM.empty - ntfSubscriptions <- TM.empty - rcvQ <- newTBQueue qSize - sndQ <- newTBQueue qSize - msgQ <- newTBQueue qSize - procThreads <- newTVar 0 - endThreads <- newTVar IM.empty - endThreadSeq <- newTVar 0 - connected <- newTVar True - rcvActiveAt <- newTVar createdAt - sndActiveAt <- newTVar createdAt + clientId <- atomically $ stateTVar nextClientId $ \next -> (next, next + 1) + subscriptions <- TM.emptyIO + ntfSubscriptions <- TM.emptyIO + rcvQ <- newTBQueueIO qSize + sndQ <- newTBQueueIO qSize + msgQ <- newTBQueueIO qSize + procThreads <- newTVarIO 0 + endThreads <- newTVarIO IM.empty + endThreadSeq <- newTVarIO 0 + connected <- newTVarIO True + rcvActiveAt <- newTVarIO createdAt + sndActiveAt <- newTVarIO createdAt return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, msgQ, procThreads, endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} newSubscription :: SubscriptionThread -> STM Sub newSubscription st = do delivered <- newEmptyTMVar - subThread <- newTVar st + subThread <- ServerSub <$> newTVar st return Sub {subThread, delivered} +newProhibitedSub :: STM Sub +newProhibitedSub = do + delivered <- newEmptyTMVar + return Sub {subThread = ProhibitSub, delivered} + newEnv :: ServerConfig -> IO Env newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile, smpAgentCfg, transportConfig, information, messageExpiration} = do - server <- atomically newServer - queueStore <- atomically newQueueStore - msgStore <- atomically newMsgStore - random <- liftIO C.newRandom - storeLog <- restoreQueues queueStore `mapM` storeLogFile + server <- newServer + queueStore <- newQueueStore + msgStore <- newMsgStore + random <- C.newRandom + storeLog <- + forM storeLogFile $ \f -> do + logInfo $ "restoring queues from file " <> T.pack f + restoreQueues queueStore f tlsServerParams <- loadTLSServerParams caCertificateFile certificateFile privateKeyFile (alpn transportConfig) Fingerprint fp <- loadFingerprint caCertificateFile let serverIdentity = KeyHash fp - serverStats <- atomically . newServerStats =<< getCurrentTime - sockets <- atomically newSocketState + serverStats <- newServerStats =<< getCurrentTime + sockets <- newSocketState clientSeq <- newTVarIO 0 clients <- newTVarIO mempty - proxyAgent <- atomically $ newSMPProxyAgent smpAgentCfg random + proxyAgent <- newSMPProxyAgent smpAgentCfg random pure Env {config, serverInfo, server, serverIdentity, queueStore, msgStore, random, storeLog, tlsServerParams, serverStats, sockets, clientSeq, clients, proxyAgent} where restoreQueues :: QueueStore -> FilePath -> IO (StoreLog 'WriteMode) restoreQueues QueueStore {queues, senders, notifiers} f = do (qs, s) <- readWriteStoreLog f - atomically $ do - writeTVar queues =<< mapM newTVar qs - writeTVar senders $! M.foldr' addSender M.empty qs - writeTVar notifiers $! M.foldr' addNotifier M.empty qs + atomically . writeTVar queues =<< mapM newTVarIO qs + atomically $ writeTVar senders $! M.foldr' addSender M.empty qs + atomically $ writeTVar notifiers $! M.foldr' addNotifier M.empty qs pure s addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) @@ -247,7 +261,7 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, | isJust (storeMsgsFile config) = SPMMessages | otherwise = SPMQueues -newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM ProxyAgent +newSMPProxyAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> IO ProxyAgent newSMPProxyAgent smpAgentCfg random = do smpAgent <- newSMPClientAgent smpAgentCfg random pure ProxyAgent {smpAgent} diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index c8f78e2fb..e0a5c8b45 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -14,8 +14,6 @@ module Simplex.Messaging.Server.MsgStore.STM getMsgQueue, delMsgQueue, delMsgQueueSize, - flushMsgQueue, - snapshotMsgQueue, writeMsg, tryPeekMsg, peekMsg, @@ -25,7 +23,6 @@ module Simplex.Messaging.Server.MsgStore.STM ) where -import Control.Concurrent.STM.TQueue (flushTQueue) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) @@ -44,8 +41,8 @@ data MsgQueue = MsgQueue type STMMsgStore = TMap RecipientId MsgQueue -newMsgStore :: STM STMMsgStore -newMsgStore = TM.empty +newMsgStore :: IO STMMsgStore +newMsgStore = TM.emptyIO getMsgQueue :: STMMsgStore -> RecipientId -> Int -> STM MsgQueue getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st @@ -64,17 +61,6 @@ delMsgQueue st rId = TM.delete rId st delMsgQueueSize :: STMMsgStore -> RecipientId -> STM Int delMsgQueueSize st rId = TM.lookupDelete rId st >>= maybe (pure 0) (\MsgQueue {size} -> readTVar size) -flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] -flushMsgQueue st rId = TM.lookupDelete rId st >>= maybe (pure []) (flushTQueue . msgQueue) - -snapshotMsgQueue :: STMMsgStore -> RecipientId -> STM [Message] -snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . msgQueue) - where - snapshotTQueue q = do - msgs <- flushTQueue q - mapM_ (writeTQueue q) msgs - pure msgs - writeMsg :: MsgQueue -> Message -> STM (Maybe (Message, Bool)) writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do canWrt <- readTVar canWrite diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index d6cdaf10a..50907cf9a 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -38,11 +38,11 @@ data QueueStore = QueueStore notifiers :: TMap NotifierId RecipientId } -newQueueStore :: STM QueueStore +newQueueStore :: IO QueueStore newQueueStore = do - queues <- TM.empty - senders <- TM.empty - notifiers <- TM.empty + queues <- TM.emptyIO + senders <- TM.emptyIO + notifiers <- TM.emptyIO pure QueueStore {queues, senders, notifiers} addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index f2716c9c3..f5b430bb6 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -27,16 +27,29 @@ data ServerStats = ServerStats qDeletedNew :: TVar Int, qDeletedSecured :: TVar Int, qSub :: TVar Int, + qSubNoMsg :: TVar Int, qSubAuth :: TVar Int, qSubDuplicate :: TVar Int, qSubProhibited :: TVar Int, + ntfCreated :: TVar Int, + ntfDeleted :: TVar Int, + ntfSub :: TVar Int, + ntfSubAuth :: TVar Int, + ntfSubDuplicate :: TVar Int, msgSent :: TVar Int, msgSentAuth :: TVar Int, msgSentQuota :: TVar Int, msgSentLarge :: TVar Int, msgRecv :: TVar Int, + msgRecvGet :: TVar Int, + msgGet :: TVar Int, + msgGetNoMsg :: TVar Int, + msgGetAuth :: TVar Int, + msgGetDuplicate :: TVar Int, + msgGetProhibited :: TVar Int, msgExpired :: TVar Int, activeQueues :: PeriodStats RecipientId, + subscribedQueues :: PeriodStats RecipientId, msgSentNtf :: TVar Int, -- sent messages with NTF flag msgRecvNtf :: TVar Int, -- received messages with NTF flag activeQueuesNtf :: PeriodStats RecipientId, @@ -60,16 +73,29 @@ data ServerStatsData = ServerStatsData _qDeletedNew :: Int, _qDeletedSecured :: Int, _qSub :: Int, + _qSubNoMsg :: Int, _qSubAuth :: Int, _qSubDuplicate :: Int, _qSubProhibited :: Int, + _ntfCreated :: Int, + _ntfDeleted :: Int, + _ntfSub :: Int, + _ntfSubAuth :: Int, + _ntfSubDuplicate :: Int, _msgSent :: Int, _msgSentAuth :: Int, _msgSentQuota :: Int, _msgSentLarge :: Int, _msgRecv :: Int, + _msgRecvGet :: Int, + _msgGet :: Int, + _msgGetNoMsg :: Int, + _msgGetAuth :: Int, + _msgGetDuplicate :: Int, + _msgGetProhibited :: Int, _msgExpired :: Int, _activeQueues :: PeriodStatsData RecipientId, + _subscribedQueues :: PeriodStatsData RecipientId, _msgSentNtf :: Int, _msgRecvNtf :: Int, _activeQueuesNtf :: PeriodStatsData RecipientId, @@ -86,38 +112,51 @@ data ServerStatsData = ServerStatsData } deriving (Show) -newServerStats :: UTCTime -> STM ServerStats +newServerStats :: UTCTime -> IO ServerStats newServerStats ts = do - fromTime <- newTVar ts - qCreated <- newTVar 0 - qSecured <- newTVar 0 - qDeletedAll <- newTVar 0 - qDeletedNew <- newTVar 0 - qDeletedSecured <- newTVar 0 - qSub <- newTVar 0 - qSubAuth <- newTVar 0 - qSubDuplicate <- newTVar 0 - qSubProhibited <- newTVar 0 - msgSent <- newTVar 0 - msgSentAuth <- newTVar 0 - msgSentQuota <- newTVar 0 - msgSentLarge <- newTVar 0 - msgRecv <- newTVar 0 - msgExpired <- newTVar 0 + fromTime <- newTVarIO ts + qCreated <- newTVarIO 0 + qSecured <- newTVarIO 0 + qDeletedAll <- newTVarIO 0 + qDeletedNew <- newTVarIO 0 + qDeletedSecured <- newTVarIO 0 + qSub <- newTVarIO 0 + qSubNoMsg <- newTVarIO 0 + qSubAuth <- newTVarIO 0 + qSubDuplicate <- newTVarIO 0 + qSubProhibited <- newTVarIO 0 + ntfCreated <- newTVarIO 0 + ntfDeleted <- newTVarIO 0 + ntfSub <- newTVarIO 0 + ntfSubAuth <- newTVarIO 0 + ntfSubDuplicate <- newTVarIO 0 + msgSent <- newTVarIO 0 + msgSentAuth <- newTVarIO 0 + msgSentQuota <- newTVarIO 0 + msgSentLarge <- newTVarIO 0 + msgRecv <- newTVarIO 0 + msgRecvGet <- newTVarIO 0 + msgGet <- newTVarIO 0 + msgGetNoMsg <- newTVarIO 0 + msgGetAuth <- newTVarIO 0 + msgGetDuplicate <- newTVarIO 0 + msgGetProhibited <- newTVarIO 0 + msgExpired <- newTVarIO 0 activeQueues <- newPeriodStats - msgSentNtf <- newTVar 0 - msgRecvNtf <- newTVar 0 + subscribedQueues <- newPeriodStats + msgSentNtf <- newTVarIO 0 + msgRecvNtf <- newTVarIO 0 activeQueuesNtf <- newPeriodStats - msgNtfs <- newTVar 0 - msgNtfNoSub <- newTVar 0 - msgNtfLost <- newTVar 0 + msgNtfs <- newTVarIO 0 + msgNtfNoSub <- newTVarIO 0 + msgNtfLost <- newTVarIO 0 pRelays <- newProxyStats pRelaysOwn <- newProxyStats pMsgFwds <- newProxyStats pMsgFwdsOwn <- newProxyStats - pMsgFwdsRecv <- newTVar 0 - qCount <- newTVar 0 - msgCount <- newTVar 0 + pMsgFwdsRecv <- newTVarIO 0 + qCount <- newTVarIO 0 + msgCount <- newTVarIO 0 pure ServerStats { fromTime, @@ -127,16 +166,29 @@ newServerStats ts = do qDeletedNew, qDeletedSecured, qSub, + qSubNoMsg, qSubAuth, qSubDuplicate, qSubProhibited, + ntfCreated, + ntfDeleted, + ntfSub, + ntfSubAuth, + ntfSubDuplicate, msgSent, msgSentAuth, msgSentQuota, msgSentLarge, msgRecv, + msgRecvGet, + msgGet, + msgGetNoMsg, + msgGetAuth, + msgGetDuplicate, + msgGetProhibited, msgExpired, activeQueues, + subscribedQueues, msgSentNtf, msgRecvNtf, activeQueuesNtf, @@ -152,38 +204,51 @@ newServerStats ts = do msgCount } -getServerStatsData :: ServerStats -> STM ServerStatsData +getServerStatsData :: ServerStats -> IO ServerStatsData getServerStatsData s = do - _fromTime <- readTVar $ fromTime s - _qCreated <- readTVar $ qCreated s - _qSecured <- readTVar $ qSecured s - _qDeletedAll <- readTVar $ qDeletedAll s - _qDeletedNew <- readTVar $ qDeletedNew s - _qDeletedSecured <- readTVar $ qDeletedSecured s - _qSub <- readTVar $ qSub s - _qSubAuth <- readTVar $ qSubAuth s - _qSubDuplicate <- readTVar $ qSubDuplicate s - _qSubProhibited <- readTVar $ qSubProhibited s - _msgSent <- readTVar $ msgSent s - _msgSentAuth <- readTVar $ msgSentAuth s - _msgSentQuota <- readTVar $ msgSentQuota s - _msgSentLarge <- readTVar $ msgSentLarge s - _msgRecv <- readTVar $ msgRecv s - _msgExpired <- readTVar $ msgExpired s + _fromTime <- readTVarIO $ fromTime s + _qCreated <- readTVarIO $ qCreated s + _qSecured <- readTVarIO $ qSecured s + _qDeletedAll <- readTVarIO $ qDeletedAll s + _qDeletedNew <- readTVarIO $ qDeletedNew s + _qDeletedSecured <- readTVarIO $ qDeletedSecured s + _qSub <- readTVarIO $ qSub s + _qSubNoMsg <- readTVarIO $ qSubNoMsg s + _qSubAuth <- readTVarIO $ qSubAuth s + _qSubDuplicate <- readTVarIO $ qSubDuplicate s + _qSubProhibited <- readTVarIO $ qSubProhibited s + _ntfCreated <- readTVarIO $ ntfCreated s + _ntfDeleted <- readTVarIO $ ntfDeleted s + _ntfSub <- readTVarIO $ ntfSub s + _ntfSubAuth <- readTVarIO $ ntfSubAuth s + _ntfSubDuplicate <- readTVarIO $ ntfSubDuplicate s + _msgSent <- readTVarIO $ msgSent s + _msgSentAuth <- readTVarIO $ msgSentAuth s + _msgSentQuota <- readTVarIO $ msgSentQuota s + _msgSentLarge <- readTVarIO $ msgSentLarge s + _msgRecv <- readTVarIO $ msgRecv s + _msgRecvGet <- readTVarIO $ msgRecvGet s + _msgGet <- readTVarIO $ msgGet s + _msgGetNoMsg <- readTVarIO $ msgGetNoMsg s + _msgGetAuth <- readTVarIO $ msgGetAuth s + _msgGetDuplicate <- readTVarIO $ msgGetDuplicate s + _msgGetProhibited <- readTVarIO $ msgGetProhibited s + _msgExpired <- readTVarIO $ msgExpired s _activeQueues <- getPeriodStatsData $ activeQueues s - _msgSentNtf <- readTVar $ msgSentNtf s - _msgRecvNtf <- readTVar $ msgRecvNtf s + _subscribedQueues <- getPeriodStatsData $ subscribedQueues s + _msgSentNtf <- readTVarIO $ msgSentNtf s + _msgRecvNtf <- readTVarIO $ msgRecvNtf s _activeQueuesNtf <- getPeriodStatsData $ activeQueuesNtf s - _msgNtfs <- readTVar $ msgNtfs s - _msgNtfNoSub <- readTVar $ msgNtfNoSub s - _msgNtfLost <- readTVar $ msgNtfLost s + _msgNtfs <- readTVarIO $ msgNtfs s + _msgNtfNoSub <- readTVarIO $ msgNtfNoSub s + _msgNtfLost <- readTVarIO $ msgNtfLost s _pRelays <- getProxyStatsData $ pRelays s _pRelaysOwn <- getProxyStatsData $ pRelaysOwn s _pMsgFwds <- getProxyStatsData $ pMsgFwds s _pMsgFwdsOwn <- getProxyStatsData $ pMsgFwdsOwn s - _pMsgFwdsRecv <- readTVar $ pMsgFwdsRecv s - _qCount <- readTVar $ qCount s - _msgCount <- readTVar $ msgCount s + _pMsgFwdsRecv <- readTVarIO $ pMsgFwdsRecv s + _qCount <- readTVarIO $ qCount s + _msgCount <- readTVarIO $ msgCount s pure ServerStatsData { _fromTime, @@ -193,16 +258,29 @@ getServerStatsData s = do _qDeletedNew, _qDeletedSecured, _qSub, + _qSubNoMsg, _qSubAuth, _qSubDuplicate, _qSubProhibited, + _ntfCreated, + _ntfDeleted, + _ntfSub, + _ntfSubAuth, + _ntfSubDuplicate, _msgSent, _msgSentAuth, _msgSentQuota, _msgSentLarge, _msgRecv, + _msgRecvGet, + _msgGet, + _msgGetNoMsg, + _msgGetAuth, + _msgGetDuplicate, + _msgGetProhibited, _msgExpired, _activeQueues, + _subscribedQueues, _msgSentNtf, _msgRecvNtf, _activeQueuesNtf, @@ -227,16 +305,29 @@ setServerStats s d = do writeTVar (qDeletedNew s) $! _qDeletedNew d writeTVar (qDeletedSecured s) $! _qDeletedSecured d writeTVar (qSub s) $! _qSub d + writeTVar (qSubNoMsg s) $! _qSubNoMsg d writeTVar (qSubAuth s) $! _qSubAuth d writeTVar (qSubDuplicate s) $! _qSubDuplicate d writeTVar (qSubProhibited s) $! _qSubProhibited d + writeTVar (ntfCreated s) $! _ntfCreated d + writeTVar (ntfDeleted s) $! _ntfDeleted d + writeTVar (ntfSub s) $! _ntfSub d + writeTVar (ntfSubAuth s) $! _ntfSubAuth d + writeTVar (ntfSubDuplicate s) $! _ntfSubDuplicate d writeTVar (msgSent s) $! _msgSent d writeTVar (msgSentAuth s) $! _msgSentAuth d writeTVar (msgSentQuota s) $! _msgSentQuota d writeTVar (msgSentLarge s) $! _msgSentLarge d writeTVar (msgRecv s) $! _msgRecv d + writeTVar (msgRecvGet s) $! _msgRecvGet d + writeTVar (msgGet s) $! _msgGet d + writeTVar (msgGetNoMsg s) $! _msgGetNoMsg d + writeTVar (msgGetAuth s) $! _msgGetAuth d + writeTVar (msgGetDuplicate s) $! _msgGetDuplicate d + writeTVar (msgGetProhibited s) $! _msgGetProhibited d writeTVar (msgExpired s) $! _msgExpired d setPeriodStats (activeQueues s) (_activeQueues d) + setPeriodStats (subscribedQueues s) (_subscribedQueues d) writeTVar (msgSentNtf s) $! _msgSentNtf d writeTVar (msgRecvNtf s) $! _msgRecvNtf d setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d) @@ -262,14 +353,26 @@ instance StrEncoding ServerStatsData where "qDeletedSecured=" <> strEncode (_qDeletedSecured d), "qCount=" <> strEncode (_qCount d), "qSub=" <> strEncode (_qSub d), + "qSubNoMsg=" <> strEncode (_qSubNoMsg d), "qSubAuth=" <> strEncode (_qSubAuth d), "qSubDuplicate=" <> strEncode (_qSubDuplicate d), "qSubProhibited=" <> strEncode (_qSubProhibited d), + "ntfCreated=" <> strEncode (_ntfCreated d), + "ntfDeleted=" <> strEncode (_ntfDeleted d), + "ntfSub=" <> strEncode (_ntfSub d), + "ntfSubAuth=" <> strEncode (_ntfSubAuth d), + "ntfSubDuplicate=" <> strEncode (_ntfSubDuplicate d), "msgSent=" <> strEncode (_msgSent d), "msgSentAuth=" <> strEncode (_msgSentAuth d), "msgSentQuota=" <> strEncode (_msgSentQuota d), "msgSentLarge=" <> strEncode (_msgSentLarge d), "msgRecv=" <> strEncode (_msgRecv d), + "msgRecvGet=" <> strEncode (_msgRecvGet d), + "msgGet=" <> strEncode (_msgGet d), + "msgGetNoMsg=" <> strEncode (_msgGetNoMsg d), + "msgGetAuth=" <> strEncode (_msgGetAuth d), + "msgGetDuplicate=" <> strEncode (_msgGetDuplicate d), + "msgGetProhibited=" <> strEncode (_msgGetProhibited d), "msgExpired=" <> strEncode (_msgExpired d), "msgSentNtf=" <> strEncode (_msgSentNtf d), "msgRecvNtf=" <> strEncode (_msgRecvNtf d), @@ -278,6 +381,8 @@ instance StrEncoding ServerStatsData where "msgNtfLost=" <> strEncode (_msgNtfLost d), "activeQueues:", strEncode (_activeQueues d), + "subscribedQueues:", + strEncode (_subscribedQueues d), "activeQueuesNtf:", strEncode (_activeQueuesNtf d), "pRelays:", @@ -299,14 +404,26 @@ instance StrEncoding ServerStatsData where <|> ((,,) <$> ("qDeletedAll=" *> strP <* A.endOfLine) <*> ("qDeletedNew=" *> strP <* A.endOfLine) <*> ("qDeletedSecured=" *> strP <* A.endOfLine)) _qCount <- opt "qCount=" _qSub <- opt "qSub=" + _qSubNoMsg <- opt "qSubNoMsg=" _qSubAuth <- opt "qSubAuth=" _qSubDuplicate <- opt "qSubDuplicate=" _qSubProhibited <- opt "qSubProhibited=" + _ntfCreated <- opt "ntfCreated=" + _ntfDeleted <- opt "ntfDeleted=" + _ntfSub <- opt "ntfSub=" + _ntfSubAuth <- opt "ntfSubAuth=" + _ntfSubDuplicate <- opt "ntfSubDuplicate=" _msgSent <- "msgSent=" *> strP <* A.endOfLine _msgSentAuth <- opt "msgSentAuth=" _msgSentQuota <- opt "msgSentQuota=" _msgSentLarge <- opt "msgSentLarge=" _msgRecv <- "msgRecv=" *> strP <* A.endOfLine + _msgRecvGet <- opt "msgRecvGet=" + _msgGet <- opt "msgGet=" + _msgGetNoMsg <- opt "msgGetNoMsg=" + _msgGetAuth <- opt "msgGetAuth=" + _msgGetDuplicate <- opt "msgGetDuplicate=" + _msgGetProhibited <- opt "msgGetProhibited=" _msgExpired <- opt "msgExpired=" _msgSentNtf <- opt "msgSentNtf=" _msgRecvNtf <- opt "msgRecvNtf=" @@ -321,6 +438,10 @@ instance StrEncoding ServerStatsData where _week <- "weekMsgQueues=" *> strP <* A.endOfLine _month <- "monthMsgQueues=" *> strP <* optional A.endOfLine pure PeriodStatsData {_day, _week, _month} + _subscribedQueues <- + optional ("subscribedQueues:" <* A.endOfLine) >>= \case + Just _ -> strP <* optional A.endOfLine + _ -> pure newPeriodStatsData _activeQueuesNtf <- optional ("activeQueuesNtf:" <* A.endOfLine) >>= \case Just _ -> strP <* optional A.endOfLine @@ -339,14 +460,26 @@ instance StrEncoding ServerStatsData where _qDeletedNew, _qDeletedSecured, _qSub, + _qSubNoMsg, _qSubAuth, _qSubDuplicate, _qSubProhibited, + _ntfCreated, + _ntfDeleted, + _ntfSub, + _ntfSubAuth, + _ntfSubDuplicate, _msgSent, _msgSentAuth, _msgSentQuota, _msgSentLarge, _msgRecv, + _msgRecvGet, + _msgGet, + _msgGetNoMsg, + _msgGetAuth, + _msgGetDuplicate, + _msgGetProhibited, _msgExpired, _msgSentNtf, _msgRecvNtf, @@ -354,6 +487,7 @@ instance StrEncoding ServerStatsData where _msgNtfNoSub, _msgNtfLost, _activeQueues, + _subscribedQueues, _activeQueuesNtf, _pRelays, _pRelaysOwn, @@ -376,11 +510,11 @@ data PeriodStats a = PeriodStats month :: TVar (Set a) } -newPeriodStats :: STM (PeriodStats a) +newPeriodStats :: IO (PeriodStats a) newPeriodStats = do - day <- newTVar S.empty - week <- newTVar S.empty - month <- newTVar S.empty + day <- newTVarIO S.empty + week <- newTVarIO S.empty + month <- newTVarIO S.empty pure PeriodStats {day, week, month} data PeriodStatsData a = PeriodStatsData @@ -393,11 +527,11 @@ data PeriodStatsData a = PeriodStatsData newPeriodStatsData :: PeriodStatsData a newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty} -getPeriodStatsData :: PeriodStats a -> STM (PeriodStatsData a) +getPeriodStatsData :: PeriodStats a -> IO (PeriodStatsData a) getPeriodStatsData s = do - _day <- readTVar $ day s - _week <- readTVar $ week s - _month <- readTVar $ month s + _day <- readTVarIO $ day s + _week <- readTVarIO $ week s + _month <- readTVarIO $ month s pure PeriodStatsData {_day, _week, _month} setPeriodStats :: PeriodStats a -> PeriodStatsData a -> STM () @@ -451,13 +585,13 @@ data ProxyStats = ProxyStats pErrorsOther :: TVar Int } -newProxyStats :: STM ProxyStats +newProxyStats :: IO ProxyStats newProxyStats = do - pRequests <- newTVar 0 - pSuccesses <- newTVar 0 - pErrorsConnect <- newTVar 0 - pErrorsCompat <- newTVar 0 - pErrorsOther <- newTVar 0 + pRequests <- newTVarIO 0 + pSuccesses <- newTVarIO 0 + pErrorsConnect <- newTVarIO 0 + pErrorsCompat <- newTVarIO 0 + pErrorsOther <- newTVarIO 0 pure ProxyStats {pRequests, pSuccesses, pErrorsConnect, pErrorsCompat, pErrorsOther} data ProxyStatsData = ProxyStatsData @@ -472,13 +606,13 @@ data ProxyStatsData = ProxyStatsData newProxyStatsData :: ProxyStatsData newProxyStatsData = ProxyStatsData {_pRequests = 0, _pSuccesses = 0, _pErrorsConnect = 0, _pErrorsCompat = 0, _pErrorsOther = 0} -getProxyStatsData :: ProxyStats -> STM ProxyStatsData +getProxyStatsData :: ProxyStats -> IO ProxyStatsData getProxyStatsData s = do - _pRequests <- readTVar $ pRequests s - _pSuccesses <- readTVar $ pSuccesses s - _pErrorsConnect <- readTVar $ pErrorsConnect s - _pErrorsCompat <- readTVar $ pErrorsCompat s - _pErrorsOther <- readTVar $ pErrorsOther s + _pRequests <- readTVarIO $ pRequests s + _pSuccesses <- readTVarIO $ pSuccesses s + _pErrorsConnect <- readTVarIO $ pErrorsConnect s + _pErrorsCompat <- readTVarIO $ pErrorsCompat s + _pErrorsOther <- readTVarIO $ pErrorsOther s pure ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} getResetProxyStatsData :: ProxyStats -> STM ProxyStatsData diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index 3ce5a35c8..45c182046 100644 --- a/src/Simplex/Messaging/Session.hs +++ b/src/Simplex/Messaging/Session.hs @@ -5,9 +5,6 @@ module Simplex.Messaging.Session where import Control.Concurrent.STM -import Control.Monad -import Data.Composition ((.:.)) -import Data.Functor (($>)) import Data.Time (UTCTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -31,14 +28,10 @@ getSessVar sessSeq sessKey vs sessionVarTs = maybe (Left <$> newSessionVar) (pur pure v removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM () -removeSessVar = void .:. removeSessVar' -{-# INLINE removeSessVar #-} - -removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool -removeSessVar' v sessKey vs = +removeSessVar v sessKey vs = TM.lookup sessKey vs >>= \case - Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True - _ -> pure False + Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs + _ -> pure () tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 2f6e0cf8a..1bc9bcb60 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -1,11 +1,13 @@ module Simplex.Messaging.TMap ( TMap, - empty, + emptyIO, singleton, clear, Simplex.Messaging.TMap.null, Simplex.Messaging.TMap.lookup, + lookupIO, member, + memberIO, insert, delete, lookupInsert, @@ -24,9 +26,9 @@ import qualified Data.Map.Strict as M type TMap k a = TVar (Map k a) -empty :: STM (TMap k a) -empty = newTVar M.empty -{-# INLINE empty #-} +emptyIO :: IO (TMap k a) +emptyIO = newTVarIO M.empty +{-# INLINE emptyIO #-} singleton :: k -> a -> STM (TMap k a) singleton k v = newTVar $ M.singleton k v @@ -44,10 +46,18 @@ lookup :: Ord k => k -> TMap k a -> STM (Maybe a) lookup k m = M.lookup k <$> readTVar m {-# INLINE lookup #-} +lookupIO :: Ord k => k -> TMap k a -> IO (Maybe a) +lookupIO k m = M.lookup k <$> readTVarIO m +{-# INLINE lookupIO #-} + member :: Ord k => k -> TMap k a -> STM Bool member k m = M.member k <$> readTVar m {-# INLINE member #-} +memberIO :: Ord k => k -> TMap k a -> IO Bool +memberIO k m = M.member k <$> readTVarIO m +{-# INLINE memberIO #-} + insert :: Ord k => k -> a -> TMap k a -> STM () insert k v m = modifyTVar' m $ M.insert k v {-# INLINE insert #-} diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 16bff693a..3386f82f3 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -286,7 +286,7 @@ getTLS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS where newTLS tlsUniq = do - tlsBuffer <- atomically newTBuffer + tlsBuffer <- newTBuffer tlsALPN <- T.getNegotiatedProtocol cxt pure TLS {tlsContext = cxt, tlsALPN, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer} diff --git a/src/Simplex/Messaging/Transport/Buffer.hs b/src/Simplex/Messaging/Transport/Buffer.hs index 6de9326f8..a612afafc 100644 --- a/src/Simplex/Messaging/Transport/Buffer.hs +++ b/src/Simplex/Messaging/Transport/Buffer.hs @@ -17,10 +17,10 @@ data TBuffer = TBuffer getLock :: TMVar () } -newTBuffer :: STM TBuffer +newTBuffer :: IO TBuffer newTBuffer = do - buffer <- newTVar "" - getLock <- newTMVar () + buffer <- newTVarIO "" + getLock <- newTMVarIO () pure TBuffer {buffer, getLock} withBufferLock :: TBuffer -> IO a -> IO a diff --git a/src/Simplex/Messaging/Transport/HTTP2.hs b/src/Simplex/Messaging/Transport/HTTP2.hs index 9c6cd7abc..3b741e6ce 100644 --- a/src/Simplex/Messaging/Transport/HTTP2.hs +++ b/src/Simplex/Messaging/Transport/HTTP2.hs @@ -75,7 +75,7 @@ instance HTTP2BodyChunk HS.Request where getHTTP2Body :: HTTP2BodyChunk a => a -> Int -> IO HTTP2Body getHTTP2Body r n = do - bodyBuffer <- atomically newTBuffer + bodyBuffer <- newTBuffer let getPart n' = getBuffered bodyBuffer n' Nothing $ getBodyChunk r bodyHead <- getPart n let bodySize = fromMaybe 0 $ getBodySize r diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 71757ca6d..d8d3d495d 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -104,13 +104,13 @@ attachHTTP2Client config host port disconnected bufferSize tls = getVerifiedHTTP getVerifiedHTTP2ClientWith :: HTTP2ClientConfig -> TransportHost -> ServiceName -> IO () -> ((TLS -> H.Client HTTP2Response) -> IO HTTP2Response) -> IO (Either HTTP2ClientError HTTP2Client) getVerifiedHTTP2ClientWith config host port disconnected setup = - (atomically mkHTTPS2Client >>= runClient) + (mkHTTPS2Client >>= runClient) `E.catch` \(e :: IOException) -> pure . Left $ HCIOError e where - mkHTTPS2Client :: STM HClient + mkHTTPS2Client :: IO HClient mkHTTPS2Client = do - connected <- newTVar False - reqQ <- newTBQueue $ qSize config + connected <- newTVarIO False + reqQ <- newTBQueueIO $ qSize config pure HClient {connected, disconnected, host, port, config, reqQ} runClient :: HClient -> IO (Either HTTP2ClientError HTTP2Client) diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index ffde39991..0b4da7833 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -76,7 +76,7 @@ serverTransportConfig TransportServerConfig {logTLSErrors} = -- All accepted connections are passed to the passed function. runTransportServer :: forall c. Transport c => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () runTransportServer started port params cfg server = do - ss <- atomically newSocketState + ss <- newSocketState runTransportServerState ss started port params cfg server runTransportServerState :: forall c . Transport c => SocketState -> TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> IO ()) -> IO () @@ -85,7 +85,7 @@ runTransportServerState ss started port = runTransportServerSocketState ss start -- | Run a transport server with provided connection setup and handler. runTransportServerSocket :: Transport a => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> IO ()) -> IO () runTransportServerSocket started getSocket threadLabel serverParams cfg server = do - ss <- atomically newSocketState + ss <- newSocketState runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server -- | Run a transport server with provided connection setup and handler. @@ -109,7 +109,7 @@ tlsServerCredentials serverParams = case T.sharedCredentials $ T.serverShared se -- | Run TCP server without TLS runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () runTCPServer started port server = do - ss <- atomically newSocketState + ss <- newSocketState runTCPServerSocket ss started (startTCPServer started port) server -- | Wrap socket provider in a TCP server bracket. @@ -148,8 +148,8 @@ safeAccept sock = type SocketState = (TVar Int, TVar Int, TVar (IntMap (Weak ThreadId))) -newSocketState :: STM SocketState -newSocketState = (,,) <$> newTVar 0 <*> newTVar 0 <*> newTVar mempty +newSocketState :: IO SocketState +newSocketState = (,,) <$> newTVarIO 0 <*> newTVarIO 0 <*> newTVarIO mempty closeServer :: TMVar Bool -> TVar (IntMap (Weak ThreadId)) -> Socket -> IO () closeServer started clients sock = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 8684c787c..5d0a2c00a 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -225,23 +225,23 @@ connectionRequestTests = queueV1NoPort #== ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-1&dh=" <> url testDhKeyStr <> "&srv=jjbyvoemxysm7qxap7m5d5m35jzv5qq6gnlv7s4rsn7tdwwmuqciwpid.onion") queueV1NoPort #== ("smp://1234-w==@smp.simplex.im,jjbyvoemxysm7qxap7m5d5m35jzv5qq6gnlv7s4rsn7tdwwmuqciwpid.onion/3456-w==#" <> testDhKeyStr) it "should serialize and parse connection invitations and contact addresses" $ do - connectionRequest #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest #== ("https://simplex.chat/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestSK #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStrSK <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest1 #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queue1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest2queues #==# ("simplex:/invitation#/?v=2-6&smp=" <> url (queueStr <> ";" <> queueStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestNew #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueNewStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestNew1 #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueNew1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequest2queuesNew #==# ("simplex:/invitation#/?v=2-6&smp=" <> url (queueNewStr <> ";" <> queueNewStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest #== ("https://simplex.chat/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestSK #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStrSK <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest1 #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queue1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest2queues #==# ("simplex:/invitation#/?v=2-7&smp=" <> url (queueStr <> ";" <> queueStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestNew #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueNewStr <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequestNew1 #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueNew1Str <> "&e2e=" <> testE2ERatchetParamsStrUri) + connectionRequest2queuesNew #==# ("simplex:/invitation#/?v=2-7&smp=" <> url (queueNewStr <> ";" <> queueNewStr) <> "&e2e=" <> testE2ERatchetParamsStrUri) connectionRequestV1 #== ("https://simplex.chat/invitation#/?v=1&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri) - connectionRequestClientDataEmpty #==# ("simplex:/invitation#/?v=2-6&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri <> "&data=" <> url "{}") - contactAddress #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueStr) - contactAddress #== ("https://simplex.chat/contact#/?v=2-6&smp=" <> url queueStr) - contactAddress2queues #==# ("simplex:/contact#/?v=2-6&smp=" <> url (queueStr <> ";" <> queueStr)) - contactAddressNew #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueNewStr) - contactAddress2queuesNew #==# ("simplex:/contact#/?v=2-6&smp=" <> url (queueNewStr <> ";" <> queueNewStr)) + connectionRequestClientDataEmpty #==# ("simplex:/invitation#/?v=2-7&smp=" <> url queueStr <> "&e2e=" <> testE2ERatchetParamsStrUri <> "&data=" <> url "{}") + contactAddress #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueStr) + contactAddress #== ("https://simplex.chat/contact#/?v=2-7&smp=" <> url queueStr) + contactAddress2queues #==# ("simplex:/contact#/?v=2-7&smp=" <> url (queueStr <> ";" <> queueStr)) + contactAddressNew #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueNewStr) + contactAddress2queuesNew #==# ("simplex:/contact#/?v=2-7&smp=" <> url (queueNewStr <> ";" <> queueNewStr)) contactAddressV2 #==# ("simplex:/contact#/?v=2&smp=" <> url queueStr) contactAddressV2 #== ("https://simplex.chat/contact#/?v=1&smp=" <> url queueStr) -- adjusted to v2 contactAddressV2 #== ("https://simplex.chat/contact#/?v=1-2&smp=" <> url queueStr) -- adjusted to v2 contactAddressV2 #== ("https://simplex.chat/contact#/?v=2-2&smp=" <> url queueStr) - contactAddressClientData #==# ("simplex:/contact#/?v=2-6&smp=" <> url queueStr <> "&data=" <> url "{\"type\":\"group_link\", \"group_link_id\":\"abc\"}") + contactAddressClientData #==# ("simplex:/contact#/?v=2-7&smp=" <> url queueStr <> "&data=" <> url "{\"type\":\"group_link\", \"group_link_id\":\"abc\"}") diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 3f70ad6ab..4d61d8463 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -38,6 +38,7 @@ module AgentTests.FunctionalAPITests rfGet, sfGet, nGet, + getInAnyOrder, (##>), (=##>), pattern CON, @@ -244,7 +245,7 @@ inAnyOrder g rs = withFrozenCallStack $ do createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) -joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId +joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE (ConnId, SndQueueSecured) joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId Nothing enableNtfs cReq connInfo PQSupportOn sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId @@ -269,13 +270,13 @@ functionalAPITests t = do describe "two way concurrently (50)" $ testMatrix2Stress t $ runAgentClientStressTestConc 25 xdescribe "two way concurrently (1000)" $ testMatrix2Stress t $ runAgentClientStressTestConc 500 describe "Establishing duplex connection, different PQ settings" $ do - testPQMatrix2 t $ runAgentClientTestPQ True + testPQMatrix2 t $ runAgentClientTestPQ False True describe "Establishing duplex connection v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientTest describe "Establish duplex connection via contact address" $ testMatrix2 t runAgentClientContactTest describe "Establish duplex connection via contact address, different PQ settings" $ do - testPQMatrix2NoInv t $ runAgentClientContactTestPQ True PQSupportOn + testPQMatrix2NoInv t $ runAgentClientContactTestPQ False True PQSupportOn describe "Establish duplex connection via contact address v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientContactTest describe "Establish duplex connection via contact address, different PQ settings" $ do @@ -356,6 +357,9 @@ functionalAPITests t = do it "should subscribe to multiple connections with pending messages" $ withSmpServer t $ testBatchedPendingMessages 10 5 + describe "Batch send messages" $ do + it "should send multiple messages to the same connection" $ withSmpServer t testSendMessagesB + it "should send messages to the 2 connections" $ withSmpServer t testSendMessagesB2 describe "Async agent commands" $ do describe "connect using async agent commands" $ testBasicMatrix2 t testAsyncCommands @@ -410,29 +414,30 @@ functionalAPITests t = do let v4 = prevVersion basicAuthSMPVersion forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do let baseId = if v >= sndAuthKeySMPVersion then 1 else 3 + sqSecured = if v >= sndAuthKeySMPVersion then True else False describe ("v" <> show v <> ": with server auth") $ do -- allow NEW | server auth, v | clnt1 auth, v | clnt2 auth, v | 2 - success, 1 - JOIN fail, 0 - NEW fail - it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 2 - it "disabled " $ testBasicAuth t False (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Nothing, v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "wrong", v) (Just "abcd", v) baseId `shouldReturn` 0 - it "NEW fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v4) (Just "abcd", v) baseId `shouldReturn` 0 - it "JOIN fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Nothing, v) baseId `shouldReturn` 1 - it "JOIN fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "wrong", v) baseId `shouldReturn` 1 - it "JOIN fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v4) baseId `shouldReturn` 1 + it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "disabled " $ testBasicAuth t False (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Nothing, v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "wrong", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "NEW fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v4) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 + it "JOIN fail, no auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Nothing, v) sqSecured baseId `shouldReturn` 1 + it "JOIN fail, bad auth " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "wrong", v) sqSecured baseId `shouldReturn` 1 + it "JOIN fail, version " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v4) sqSecured baseId `shouldReturn` 1 describe ("v" <> show v <> ": no server auth") $ do - it "success " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v) baseId `shouldReturn` 2 - it "srv disabled" $ testBasicAuth t False (Nothing, v) (Nothing, v) (Nothing, v) baseId `shouldReturn` 0 - it "version srv " $ testBasicAuth t True (Nothing, v4) (Nothing, v) (Nothing, v) 3 `shouldReturn` 2 - it "version fst " $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v) baseId `shouldReturn` 2 - it "version snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v4) 3 `shouldReturn` 2 - it "version both" $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v4) 3 `shouldReturn` 2 - it "version all " $ testBasicAuth t True (Nothing, v4) (Nothing, v4) (Nothing, v4) 3 `shouldReturn` 2 - it "auth fst " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Nothing, v) baseId `shouldReturn` 2 - it "auth fst 2 " $ testBasicAuth t True (Nothing, v4) (Just "abcd", v) (Nothing, v) 3 `shouldReturn` 2 - it "auth snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Just "abcd", v) baseId `shouldReturn` 2 - it "auth both " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 2 - it "auth, disabled" $ testBasicAuth t False (Nothing, v) (Just "abcd", v) (Just "abcd", v) baseId `shouldReturn` 0 + it "success " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v) sqSecured baseId `shouldReturn` 2 + it "srv disabled" $ testBasicAuth t False (Nothing, v) (Nothing, v) (Nothing, v) sqSecured baseId `shouldReturn` 0 + it "version srv " $ testBasicAuth t True (Nothing, v4) (Nothing, v) (Nothing, v) False 3 `shouldReturn` 2 + it "version fst " $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v) False baseId `shouldReturn` 2 + it "version snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Nothing, v4) sqSecured 3 `shouldReturn` 2 + it "version both" $ testBasicAuth t True (Nothing, v) (Nothing, v4) (Nothing, v4) False 3 `shouldReturn` 2 + it "version all " $ testBasicAuth t True (Nothing, v4) (Nothing, v4) (Nothing, v4) False 3 `shouldReturn` 2 + it "auth fst " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Nothing, v) sqSecured baseId `shouldReturn` 2 + it "auth fst 2 " $ testBasicAuth t True (Nothing, v4) (Just "abcd", v) (Nothing, v) False 3 `shouldReturn` 2 + it "auth snd " $ testBasicAuth t True (Nothing, v) (Nothing, v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "auth both " $ testBasicAuth t True (Nothing, v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 2 + it "auth, disabled" $ testBasicAuth t False (Nothing, v) (Just "abcd", v) (Just "abcd", v) sqSecured baseId `shouldReturn` 0 describe "SMP server test via agent API" $ do it "should pass without basic auth" $ testSMPServerConnectionTest t Nothing (noAuthSrv testSMPServer2) `shouldReturn` Nothing let srv1 = testSMPServer2 {keyHash = "1234"} @@ -460,8 +465,8 @@ functionalAPITests t = do it "server should respond with queue and subscription information" $ withSmpServer t testServerQueueInfo -testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> AgentMsgId -> IO Int -testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 baseId = do +testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> SndQueueSecured -> AgentMsgId -> IO Int +testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 sqSecured baseId = do let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = V.mkVersionRange batchCmdsSMPVersion srvVersion} canCreate1 = canCreateQueue allowNewQueues srv clnt1 canCreate2 = canCreateQueue allowNewQueues srv clnt2 @@ -469,7 +474,7 @@ testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 baseId = do | canCreate1 && canCreate2 = 2 | canCreate1 = 1 | otherwise = 0 - created <- withSmpServerConfigOn t testCfg testPort $ \_ -> testCreateQueueAuth srvVersion clnt1 clnt2 baseId + created <- withSmpServerConfigOn t testCfg testPort $ \_ -> testCreateQueueAuth srvVersion clnt1 clnt2 sqSecured baseId created `shouldBe` expected pure created @@ -478,43 +483,43 @@ canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) -testMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2 :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do - it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True - it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True - it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn False - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff False - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False - -testMatrix2Stress :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn True False + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff False False + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False False + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False False + +testMatrix2Stress :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2Stress t runTest = do - it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aCfg aCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True - it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aProxyCfgV8 aProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True - it "current" $ withSmpServer t $ runTestCfg2 aCfg aCfg 1 $ runTest PQSupportOn False - it "prev" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfgVPrev 3 $ runTest PQSupportOff False - it "prev to current" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfg 3 $ runTest PQSupportOff False - it "current to prev" $ withSmpServer t $ runTestCfg2 aCfg aCfgVPrev 3 $ runTest PQSupportOff False + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aCfg aCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aProxyCfgV8 aProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "current" $ withSmpServer t $ runTestCfg2 aCfg aCfg 1 $ runTest PQSupportOn True False + it "prev" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfgVPrev 3 $ runTest PQSupportOff False False + it "prev to current" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfg 3 $ runTest PQSupportOff False False + it "current to prev" $ withSmpServer t $ runTestCfg2 aCfg aCfgVPrev 3 $ runTest PQSupportOff False False where aCfg = agentCfg {messageRetryInterval = fastMessageRetryInterval} aProxyCfgV8 = agentProxyCfgV8 {messageRetryInterval = fastMessageRetryInterval} aCfgVPrev = agentCfgVPrev {messageRetryInterval = fastMessageRetryInterval} -testBasicMatrix2 :: HasCallStack => ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testBasicMatrix2 :: HasCallStack => ATransport -> (SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testBasicMatrix2 t runTest = do - it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfgVPrevPQ 3 $ runTest - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfg 3 $ runTest - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrevPQ 3 $ runTest + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest True + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfgVPrevPQ 3 $ runTest False + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrevPQ agentCfg 3 $ runTest False + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrevPQ 3 $ runTest False -testRatchetMatrix2 :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testRatchetMatrix2 :: HasCallStack => ATransport -> (PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do - it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True - it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True - it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn False - it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 1 $ runTest PQSupportOff False - it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 1 $ runTest PQSupportOff False - it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 1 $ runTest PQSupportOff False + it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentCfg agentCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True True + it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 agentProxyCfgV8 agentProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn False True + it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest PQSupportOn True False + it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 1 $ runTest PQSupportOff True False + it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 1 $ runTest PQSupportOff True False + it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 1 $ runTest PQSupportOff True False testServerMatrix2 :: HasCallStack => ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -589,15 +594,16 @@ withAgentClients3 runTest = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> runTest a b c -runAgentClientTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientTest pqSupport viaProxy alice bob baseId = - runAgentClientTestPQ viaProxy (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId +runAgentClientTest :: HasCallStack => PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest pqSupport sqSecured viaProxy alice bob baseId = + runAgentClientTestPQ sqSecured viaProxy (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId -runAgentClientTestPQ :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () -runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId = +runAgentClientTestPQ :: HasCallStack => SndQueueSecured -> Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () +runAgentClientTestPQ sqSecured viaProxy (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing aPQ SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" bPQ SMSubscribe + (aliceId, sqSecured') <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" bPQ SMSubscribe + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` CR.connPQEncryption aPQ allowConnection alice bobId confId "alice's connInfo" @@ -634,10 +640,10 @@ runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId = pqConnectionMode :: InitialKeys -> PQSupport -> Bool pqConnectionMode pqMode1 pqMode2 = supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2 -runAgentClientStressTestOneWay :: HasCallStack => Int64 -> PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientStressTestOneWay n pqSupport viaProxy alice bob baseId = runRight_ $ do +runAgentClientStressTestOneWay :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do let pqEnc = PQEncryption $ supportPQ pqSupport - (aliceId, bobId) <- makeConnection_ pqSupport alice bob + (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob let proxySrv = if viaProxy then Just testSMPServer else Nothing message i = "message " <> bshow i concurrently_ @@ -666,10 +672,10 @@ runAgentClientStressTestOneWay n pqSupport viaProxy alice bob baseId = runRight_ where msgId = subtract baseId . fst -runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientStressTestConc n pqSupport viaProxy alice bob baseId = runRight_ $ do +runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do let pqEnc = PQEncryption $ supportPQ pqSupport - (aliceId, bobId) <- makeConnection_ pqSupport alice bob + (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob let proxySrv = if viaProxy then Just testSMPServer else Nothing message i = "message " <> bshow i loop a bId mIdVar i = do @@ -703,7 +709,7 @@ testEnablePQEncryption :: HasCallStack => IO () testEnablePQEncryption = withAgentClients2 $ \ca cb -> runRight_ $ do g <- liftIO C.newRandom - (aId, bId) <- makeConnection_ PQSupportOff ca cb + (aId, bId) <- makeConnection_ PQSupportOff True ca cb let a = (ca, aId) b = (cb, bId) (a, 2, "msg 1") \#>\ b @@ -789,20 +795,23 @@ testAgentClient3 = get c =##> \case ("", connId, Msg "c5") -> connId == aIdForC; _ -> False ackMessage c aIdForC 3 Nothing -runAgentClientContactTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientContactTest pqSupport viaProxy alice bob baseId = - runAgentClientContactTestPQ viaProxy pqSupport (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId +runAgentClientContactTest :: HasCallStack => PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest pqSupport sqSecured viaProxy alice bob baseId = + runAgentClientContactTestPQ sqSecured viaProxy pqSupport (alice, IKNoPQ pqSupport) (bob, pqSupport) baseId -runAgentClientContactTestPQ :: HasCallStack => Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () -runAgentClientContactTestPQ viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = +runAgentClientContactTestPQ :: HasCallStack => SndQueueSecured -> Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () +runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ - aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" bPQ SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecuredJoin) <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" bPQ SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` reqPQSupport - bobId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (bobId, sqSecured') <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob liftIO $ pqSup'' `shouldBe` bPQ allowConnection bob aliceId confId "bob's connInfo" @@ -847,11 +856,14 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId msgId = subtract baseId . fst connectViaContact b pq qInfo = do aId <- A.prepareConnectionToJoin b 1 True qInfo pq - aId' <- A.joinConnection b 1 (Just aId) True qInfo "bob's connInfo" pq SMSubscribe - liftIO $ aId' `shouldBe` aId + (aId', sqSecuredJoin) <- A.joinConnection b 1 (Just aId) True qInfo "bob's connInfo" pq SMSubscribe + liftIO $ do + aId' `shouldBe` aId + sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn - bId <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + (bId, sqSecuredAccept) <- acceptContact alice True invId "alice's connInfo" (CR.connPQEncryption aPQ) SMSubscribe + liftIO $ sqSecuredAccept `shouldBe` False -- agent cfg is v8 ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get b liftIO $ pqSup'' `shouldBe` pq allowConnection b aId confId "bob's connInfo" @@ -891,8 +903,10 @@ testRejectContactRequest = withAgentClients2 $ \alice bob -> runRight_ $ do (addrConnId, qInfo) <- A.createConnection alice 1 True SCMContact Nothing IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn - aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecured) <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured `shouldBe` False -- joining via contact address connection ("", _, A.REQ invId PQSupportOn _ "bob's connInfo") <- get alice liftIO $ runExceptT (rejectContact alice "abcd" invId) `shouldReturn` Left (CONN NOT_FOUND) rejectContact alice addrConnId invId @@ -904,15 +918,34 @@ testAsyncInitiatingOffline = alice <- liftIO $ getSMPAgentClient' 1 agentCfg initAgentServers testDB (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ disposeAgentClient alice - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True + + -- send messages + msgId1 <- A.sendMessage bob aliceId PQEncOn SMP.noMsgFlags "can send 1" + liftIO $ msgId1 `shouldBe` (2, PQEncOff) + get bob ##> ("", aliceId, SENT 2) + msgId2 <- A.sendMessage bob aliceId PQEncOn SMP.noMsgFlags "can send 2" + liftIO $ msgId2 `shouldBe` (3, PQEncOff) + get bob ##> ("", aliceId, SENT 3) + alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId ("", _, CONF confId _ "bob's connInfo") <- get alice' + -- receive messages + get alice' =##> \case ("", c, Msg' mId pq "can send 1") -> c == bobId && mId == 1 && pq == PQEncOff; _ -> False + ackMessage alice' bobId 1 Nothing + get alice' =##> \case ("", c, Msg' mId pq "can send 2") -> c == bobId && mId == 2 && pq == PQEncOff; _ -> False + ackMessage alice' bobId 2 Nothing + -- for alice msg id 3 is sent confirmation, then they're matched with bob at msg id 4 + + -- allow connection allowConnection alice' bobId confId "alice's connInfo" get alice' ##> ("", bobId, CON) get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, CON) - exchangeGreetings alice' bobId bob aliceId + exchangeGreetingsMsgId 4 alice' bobId bob aliceId liftIO $ disposeAgentClient alice' testAsyncJoiningOfflineBeforeActivation :: HasCallStack => IO () @@ -920,7 +953,8 @@ testAsyncJoiningOfflineBeforeActivation = withAgent 1 agentCfg initAgentServers testDB $ \alice -> runRight_ $ do bob <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB2 (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True liftIO $ disposeAgentClient bob ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" @@ -939,7 +973,8 @@ testAsyncBothOffline = do runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ disposeAgentClient alice - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True liftIO $ disposeAgentClient bob alice' <- liftIO $ getSMPAgentClient' 3 agentCfg initAgentServers testDB subscribeConnection alice' bobId @@ -970,7 +1005,8 @@ testAsyncServerOffline t = withAgentClients2 $ \alice bob -> do liftIO $ do srv1 `shouldBe` testSMPServer conns1 `shouldBe` [bobId] - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -988,7 +1024,8 @@ testAllowConnectionClientRestart t = do withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \_ -> do runRight $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice pure (aliceId, bobId, confId) @@ -1024,7 +1061,7 @@ testIncreaseConnAgentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1089,7 +1126,7 @@ testIncreaseConnAgentVersionMaxCompatible t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1119,7 +1156,7 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + (aliceId, bobId) <- makeConnection_ PQSupportOff False alice bob exchangeGreetingsMsgId_ PQEncOff 2 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 @@ -1620,7 +1657,8 @@ testRatchetSyncSimultaneous t = do testOnlyCreatePullSlowHandshake :: IO () testOnlyCreatePullSlowHandshake = withAgentClientsCfg2 agentProxyCfgV8 agentProxyCfgV8 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + liftIO $ sqSecured `shouldBe` False Just ("", _, CONF confId _ "bob's connInfo") <- getMsg alice bobId $ timeout 5_000000 $ get alice allowConnection alice bobId confId "alice's connInfo" liftIO $ threadDelay 1_000000 @@ -1654,7 +1692,8 @@ getMsg c cId action = do testOnlyCreatePull :: IO () testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMOnlyCreate - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMOnlyCreate + liftIO $ sqSecured `shouldBe` True Just ("", _, CONF confId _ "bob's connInfo") <- getMsg alice bobId $ timeout 5_000000 $ get alice allowConnection alice bobId confId "alice's connInfo" liftIO $ threadDelay 1_000000 @@ -1676,20 +1715,22 @@ testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do ackMessage alice bobId 3 Nothing makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection = makeConnection_ PQSupportOn +makeConnection = makeConnection_ PQSupportOn True -makeConnection_ :: PQSupport -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 +makeConnection_ :: PQSupport -> SndQueueSecured -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnection_ pqEnc sqSecured alice bob = makeConnectionForUsers_ pqEnc sqSecured alice 1 bob 1 makeConnectionForUsers :: HasCallStack => AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn +makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn True -makeConnectionForUsers_ :: HasCallStack => PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do +makeConnectionForUsers_ :: HasCallStack => PQSupport -> SndQueueSecured -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ pqSupport sqSecured alice aliceUserId bob bobUserId = do (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport - aliceId' <- A.joinConnection bob bobUserId (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe - liftIO $ aliceId' `shouldBe` aliceId + (aliceId', sqSecured') <- A.joinConnection bob bobUserId (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured' `shouldBe` sqSecured ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" @@ -1772,7 +1813,6 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False ackMessage b aId 2 Nothing pure (aId, bId) - runRight_ $ do ("", "", DOWN {}) <- nGet a ("", "", DOWN {}) <- nGet b @@ -1780,15 +1820,17 @@ testSuspendingAgentCompleteSending t = withAgentClients2 $ \a b -> do 4 <- sendMessage b aId SMP.noMsgFlags "how are you?" liftIO $ threadDelay 100000 liftIO $ suspendAgent b 5000000 - withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do - pGet b =##> \case ("", c, AEvt SAEConn (SENT 3)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet b =##> \case ("", c, AEvt SAEConn (SENT 3)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet b =##> \case ("", c, AEvt SAEConn (SENT 4)) -> c == aId; ("", "", AEvt _ UP {}) -> True; _ -> False - ("", "", SUSPENDED) <- nGet b - - pGet a =##> \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; ("", "", AEvt _ UP {}) -> True; _ -> False - pGet a =##> \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; ("", "", AEvt _ UP {}) -> True; _ -> False + -- there will be no UP event for b, because re-subscriptions are suspended until the agent is in foreground + get b =##> \case ("", c, SENT 3) -> c == aId; _ -> False + get b =##> \case ("", c, SENT 4) -> c == aId; _ -> False + nGet b ##> ("", "", SUSPENDED) + liftIO $ + getInAnyOrder + a + [ \case ("", c, AEvt _ (Msg "hello too")) -> c == bId; _ -> False, + \case ("", "", AEvt _ UP {}) -> True; _ -> False + ] ackMessage a bId 3 Nothing get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False ackMessage a bId 4 Nothing @@ -1816,7 +1858,7 @@ testBatchedSubscriptions :: Int -> Int -> ATransport -> IO () testBatchedSubscriptions nCreate nDel t = withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do conns <- runServers $ do - conns <- replicateM nCreate $ makeConnection_ PQSupportOff a b + conns <- replicateM nCreate $ makeConnection_ PQSupportOff True a b forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' @@ -1894,15 +1936,59 @@ testBatchedPendingMessages nCreate nMsgs = withA = withAgent 1 agentCfg initAgentServers testDB withB = withAgent 2 agentCfg initAgentServers testDB2 -testAsyncCommands :: AgentClient -> AgentClient -> AgentMsgId -> IO () -testAsyncCommands alice bob baseId = +testSendMessagesB :: IO () +testSendMessagesB = withAgentClients2 $ \a b -> runRight_ $ do + (aId, bId) <- makeConnection a b + let msg cId body = Right (cId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4] <- sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3"] :: [Either AgentErrorType MsgReq]) + get a ##> ("", bId, SENT 2) + get a ##> ("", bId, SENT 3) + get a ##> ("", bId, SENT 4) + receiveMsg b aId 2 "msg 1" + receiveMsg b aId 3 "msg 2" + receiveMsg b aId 4 "msg 3" + +testSendMessagesB2 :: IO () +testSendMessagesB2 = withAgentClients3 $ \a b c -> runRight_ $ do + (abId, bId) <- makeConnection a b + (acId, cId) <- makeConnection a c + let msg connId body = Right (connId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4, SentB 2, SentB 3] <- + sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3", msg cId "msg 4", msg "" "msg 5"] :: [Either AgentErrorType MsgReq]) + liftIO $ + getInAnyOrder + a + [ \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 4)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == cId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == cId; _ -> False + ] + receiveMsg b abId 2 "msg 1" + receiveMsg b abId 3 "msg 2" + receiveMsg b abId 4 "msg 3" + receiveMsg c acId 2 "msg 4" + receiveMsg c acId 3 "msg 5" + +pattern SentB :: AgentMsgId -> Either AgentErrorType (AgentMsgId, PQEncryption) +pattern SentB msgId <- Right (msgId, PQEncOn) + +receiveMsg :: AgentClient -> ConnId -> AgentMsgId -> MsgBody -> ExceptT AgentErrorType IO () +receiveMsg c cId msgId msg = do + get c =##> \case ("", cId', Msg' mId' PQEncOn msg') -> cId' == cId && mId' == msgId && msg' == msg; _ -> False + ackMessage c cId msgId Nothing + +testAsyncCommands :: SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO () +testAsyncCommands sqSecured alice bob baseId = runRight_ $ do bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe - ("2", aliceId', OK) <- get bob - liftIO $ aliceId' `shouldBe` aliceId + ("2", aliceId', JOINED sqSecured') <- get bob + liftIO $ do + aliceId' `shouldBe` aliceId + sqSecured' `shouldBe` sqSecured ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnectionAsync alice "3" bobId confId "alice's connInfo" get alice =##> \case ("3", _, OK) -> True; _ -> False @@ -1955,14 +2041,15 @@ testAsyncCommandsRestore t = do get alice' =##> \case ("1", _, INV _) -> True; _ -> False pure () -testAcceptContactAsync :: AgentClient -> AgentClient -> AgentMsgId -> IO () -testAcceptContactAsync alice bob baseId = +testAcceptContactAsync :: SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO () +testAcceptContactAsync sqSecured alice bob baseId = runRight_ $ do (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecuredJoin) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection ("", _, REQ invId _ "bob's connInfo") <- get alice bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQSupportOn SMSubscribe - get alice =##> \case ("1", c, OK) -> c == bobId; _ -> False + get alice =##> \case ("1", c, JOINED sqSecured') -> c == bobId && sqSecured' == sqSecured; _ -> False ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" get alice ##> ("", bobId, INFO "bob's connInfo") @@ -2238,7 +2325,7 @@ testJoinConnectionAsyncReplyErrorV8 t = do pure (aId, bId) nGet a =##> \case ("", "", DOWN _ [c]) -> c == bId; _ -> False withSmpServerOn t testPort2 $ do - get b =##> \case ("2", c, OK) -> c == aId; _ -> False + get b =##> \case ("2", c, JOINED sqSecured) -> c == aId && not sqSecured; _ -> False confId <- withSmpServerStoreLogOn t testPort $ \_ -> do pGet a >>= \case ("", "", AEvt _ (UP _ [_])) -> do @@ -2279,7 +2366,7 @@ testJoinConnectionAsyncReplyError t = do withSmpServerOn t testPort2 $ do confId <- withSmpServerStoreLogOn t testPort $ \_ -> do -- both servers need to be online for connection to progress because of SKEY - get b =##> \case ("2", c, OK) -> c == aId; _ -> False + get b =##> \case ("2", c, JOINED sqSecured) -> c == aId && sqSecured; _ -> False pGet a >>= \case ("", "", AEvt _ (UP _ [_])) -> do ("", _, CONF confId _ "bob's connInfo") <- get a @@ -2733,8 +2820,8 @@ testSwitch2ConnectionsAbort1 servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> AgentMsgId -> IO Int -testCreateQueueAuth srvVersion clnt1 clnt2 baseId = do +testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> SndQueueSecured -> AgentMsgId -> IO Int +testCreateQueueAuth srvVersion clnt1 clnt2 sqSecured baseId = do a <- getClient 1 clnt1 testDB b <- getClient 2 clnt2 testDB2 r <- runRight $ do @@ -2745,7 +2832,8 @@ testCreateQueueAuth srvVersion clnt1 clnt2 baseId = do tryError (joinConnection b 1 True qInfo "bob's connInfo" SMSubscribe) >>= \case Left (SMP _ AUTH) -> pure 1 Left e -> throwError e - Right aId -> do + Right (aId, sqSecured') -> do + liftIO $ sqSecured' `shouldBe` sqSecured ("", _, CONF confId _ "bob's connInfo") <- get a allowConnection a bId confId "alice's connInfo" get a ##> ("", bId, CON) @@ -2805,7 +2893,7 @@ testDeliveryReceiptsVersion t = do b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aId, bId) <- runRight $ do - (aId, bId) <- makeConnection_ PQSupportOff a b + (aId, bId) <- makeConnection_ PQSupportOff False a b checkVersion a bId 3 checkVersion b aId 3 (2, _) <- A.sendMessage a bId PQEncOff SMP.noMsgFlags "hello" @@ -2829,8 +2917,8 @@ testDeliveryReceiptsVersion t = do subscribeConnection a' bId subscribeConnection b' aId exchangeGreetingsMsgId_ PQEncOff 4 a' bId b' aId - checkVersion a' bId 6 - checkVersion b' aId 6 + checkVersion a' bId 7 + checkVersion b' aId 7 (6, PQEncOff) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello" get a' ##> ("", bId, SENT 6) get b' =##> \case ("", c, Msg' 6 PQEncOff "hello") -> c == aId; _ -> False @@ -2979,7 +3067,8 @@ testServerMultipleIdentities :: HasCallStack => IO () testServerMultipleIdentities = withAgentClients2 $ \alice bob -> runRight_ $ do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -3078,7 +3167,8 @@ testServerQueueInfo = do (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ threadDelay 200000 checkEmptyQ alice bobId False - aliceId <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice liftIO $ threadDelay 200000 checkEmptyQ alice bobId True -- secured by sender diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index fd737e913..cc79faeca 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -477,7 +477,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali (bobId, aliceId, nonce, message) <- runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, _sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -511,7 +511,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali -- aliceNtf client doesn't have subscription and is allowed to get notification message withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> runRight_ $ do - (_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message + (_, Just SMPMsgMeta {msgFlags = MsgFlags True}) <- getNotificationMessage aliceNtf nonce message pure () threadDelay 1000000 @@ -544,7 +544,7 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} baseId alice bo liftIO $ threadDelay 50000 (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe liftIO $ threadDelay 1000000 - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, _sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe liftIO $ threadDelay 750000 void $ messageNotificationData alice apnsQ ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -591,7 +591,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = withAgentClients2 $ \alice bob -> runRight_ $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -653,7 +654,8 @@ testChangeToken APNSMockServer {apnsQ} = withAgent 1 agentCfg initAgentServers t (aliceId, bobId) <- withAgent 2 agentCfg initAgentServers testDB $ \alice -> runRight $ do -- establish connection (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + (aliceId, sqSecured) <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index caab0637a..5f6beb034 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -261,7 +261,7 @@ testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) testClientStub = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g - atomically $ smpClientStub g sessId subModeSMPVersion Nothing + smpClientStub g sessId subModeSMPVersion Nothing clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) clientStubV7 = do @@ -269,7 +269,7 @@ clientStubV7 = do sessId <- atomically $ C.randomBytes 32 g (rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey - atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_ + smpClientStub g sessId authCmdsSMPVersion thAuth_ randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion diff --git a/tests/CoreTests/RetryIntervalTests.hs b/tests/CoreTests/RetryIntervalTests.hs index 7097df989..da96d0208 100644 --- a/tests/CoreTests/RetryIntervalTests.hs +++ b/tests/CoreTests/RetryIntervalTests.hs @@ -2,6 +2,8 @@ module CoreTests.RetryIntervalTests where +import Control.Concurrent (threadDelay) +import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Monad (when) import Data.Time.Clock (UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds) @@ -13,6 +15,10 @@ retryIntervalTests = do describe "Retry interval with 2 modes and lock" $ do testRetryIntervalSameMode testRetryIntervalSwitchMode + describe "Foreground retry interval" $ do + testRetryForeground + testRetryToBackground + testRetrySkipWhenForeground testRI :: RetryInterval2 testRI = @@ -23,12 +29,15 @@ testRI = increaseAfter = 40000, maxInterval = 40000 }, - riFast = - RetryInterval - { initialInterval = 10000, - increaseAfter = 20000, - maxInterval = 40000 - } + riFast = testFastRI + } + +testFastRI :: RetryInterval +testFastRI = + RetryInterval + { initialInterval = 10000, + increaseAfter = 20000, + maxInterval = 40000 } testRetryIntervalSameMode :: Spec @@ -81,6 +90,67 @@ testRetryIntervalSwitchMode = (40000, 40000) ] +testRetryForeground :: Spec +testRetryForeground = + it "should increase elapased time and interval" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + let isForeground = pure True + withRetryForeground testFastRI isForeground (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 8) $ loop + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 40000, 40000, 40000] + +testRetryToBackground :: Spec +testRetryToBackground = + it "should not change interval when moving to background" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + foreground <- newTVarIO True + concurrently_ + ( do + threadDelay 50000 + atomically $ writeTVar foreground False + ) + ( withRetryForeground testFastRI (readTVar foreground) (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 8) $ loop + ) + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 3, 4, 4] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 40000, 40000, 40000] + +testRetrySkipWhenForeground :: Spec +testRetrySkipWhenForeground = + it "should repeat loop as soon as moving to foreground" $ do + intervals <- newTVarIO [] + reportedIntervals <- newTVarIO [] + ts <- newTVarIO =<< getCurrentTime + foreground <- newTVarIO False + concurrently_ + ( do + threadDelay 65000 + atomically $ writeTVar foreground True + threadDelay 10000 + atomically $ writeTVar foreground False + threadDelay 100000 + atomically $ writeTVar foreground True + ) + ( withRetryForeground testFastRI (readTVar foreground) (pure True) $ \delay loop -> do + ints <- addInterval intervals ts + atomically $ modifyTVar' reportedIntervals (delay :) + when (length ints < 12) $ loop + ) + (reverse <$> readTVarIO intervals) `shouldReturn` [0, 1, 1, 1, 2, 0, 1, 1, 1, 2, 3, 1] + (reverse <$> readTVarIO reportedIntervals) + `shouldReturn` [ 10000, 10000, 15000, 22500, 33750, 10000, 10000, 15000, 22500, 33750, 40000, 10000] + addInterval :: TVar [Int] -> TVar UTCTime -> IO [Int] addInterval intervals ts = do ts' <- getCurrentTime diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 9f7c4932e..24d54fc8e 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where @@ -30,19 +31,19 @@ tRcvQueuesTests = do describe "queue transfer" $ do it "getDelSessQueues-batchAddQueues preserves total length" removeSubsTest -checkDataInvariant :: RQ.TRcvQueues -> IO Bool +checkDataInvariant :: RQ.Queue q => RQ.TRcvQueues q -> IO Bool checkDataInvariant trq = atomically $ do conns <- readTVar $ RQ.getConnections trq qs <- readTVar $ RQ.getRcvQueues trq -- three invariant checks - let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> connId q == cId) qs))) (M.keys conns) - inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (connId q) conns)) (M.assocs qs) + let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> RQ.connId' q == cId) qs))) (M.keys conns) + inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (RQ.connId' q) conns)) (M.assocs qs) inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs) pure $ inv1 && inv2 && inv3 hasConnTest :: IO () hasConnTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -56,7 +57,7 @@ hasConnTest = do hasConnTestBatch :: IO () hasConnTestBatch = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True @@ -67,7 +68,7 @@ hasConnTestBatch = do batchIdempotentTest :: IO () batchIdempotentTest = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True @@ -76,11 +77,11 @@ batchIdempotentTest = do atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' - fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn`cs' -- connections get duplicated, but that doesn't appear to affect anybody + fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn` cs' -- connections get duplicated, but that doesn't appear to affect anybody deleteConnTest :: IO () deleteConnTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ do RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -94,7 +95,7 @@ deleteConnTest = do getSessQueuesTest :: IO () getSessQueuesTest = do - trq <- atomically RQ.empty + trq <- RQ.empty atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1") trq checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2") trq @@ -103,32 +104,40 @@ getSessQueuesTest = do checkDataInvariant trq `shouldReturn` True atomically $ RQ.addQueue (dummyRQ 1 "smp://1234-w==@beta" "c4") trq checkDataInvariant trq `shouldReturn` True - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"] - atomically (RQ.getSessQueues (1, "smp://1234-w==@alpha", Just "c1") trq) `shouldReturn` [] - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Just "nope") trq) `shouldReturn` [] - atomically (RQ.getSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"] + let tSess1 = (0, "smp://1234-w==@alpha", Just "c1") + RQ.getSessQueues tSess1 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1"] + atomically (RQ.hasSessQueues tSess1 trq) `shouldReturn` True + let tSess2 = (1, "smp://1234-w==@alpha", Just "c1") + RQ.getSessQueues tSess2 trq `shouldReturn` [] + atomically (RQ.hasSessQueues tSess2 trq) `shouldReturn` False + let tSess3 = (0, "smp://1234-w==@alpha", Just "nope") + RQ.getSessQueues tSess3 trq `shouldReturn` [] + atomically (RQ.hasSessQueues tSess3 trq) `shouldReturn` False + let tSess4 = (0, "smp://1234-w==@alpha", Nothing) + RQ.getSessQueues tSess4 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"] + atomically (RQ.hasSessQueues tSess4 trq) `shouldReturn`True getDelSessQueuesTest :: IO () getDelSessQueuesTest = do - trq <- atomically RQ.empty + trq <- RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True -- no user - atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- wrong user - atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([], []) + atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) checkDataInvariant trq `shouldReturn` True -- connections intact atomically (RQ.hasConn "c1" trq) `shouldReturn` True atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) + atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@alpha" "c1"], ["c1", "c2"]) checkDataInvariant trq `shouldReturn` True -- connections gone atomically (RQ.hasConn "c1" trq) `shouldReturn` False @@ -139,31 +148,31 @@ getDelSessQueuesTest = do removeSubsTest :: IO () removeSubsTest = do - aq <- atomically RQ.empty + aq <- RQ.empty let qs = - [ dummyRQ 0 "smp://1234-w==@alpha" "c1", - dummyRQ 0 "smp://1234-w==@alpha" "c2", - dummyRQ 0 "smp://1234-w==@beta" "c3", - dummyRQ 1 "smp://1234-w==@beta" "c4" + [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1"), + ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2"), + ("1", dummyRQ 0 "smp://1234-w==@beta" "c3"), + ("1", dummyRQ 1 "smp://1234-w==@beta" "c4") ] atomically $ RQ.batchAddQueues aq qs - pq <- atomically RQ.empty + pq <- RQ.empty atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") aq >>= RQ.batchAddQueues pq . fst + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst atomically (totalSize aq pq) `shouldReturn` (4, 4) -totalSize :: RQ.TRcvQueues -> RQ.TRcvQueues -> STM (Int, Int) +totalSize :: RQ.TRcvQueues q -> RQ.TRcvQueues q -> STM (Int, Int) totalSize a b = do qsizeA <- M.size <$> readTVar (RQ.getRcvQueues a) qsizeB <- M.size <$> readTVar (RQ.getRcvQueues b) diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index e0de57466..0bb050cbe 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -72,8 +72,6 @@ agentCfg = ntfCfg = defaultNTFClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS), networkConfig}, reconnectInterval = fastRetryInterval, persistErrorInterval = 1, - ntfWorkerDelay = 100, - ntfSMPWorkerDelay = 100, caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt" diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 7505ef977..8044d23f7 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -207,7 +207,8 @@ agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, b withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn allowConnection alice bobId confId "alice's connInfo" @@ -261,7 +262,8 @@ agentDeliverMessagesViaProxyConc agentServers msgs = -- otherwise the CONF messages would get mixed with MSG prePair alice bob = do (bobId, qInfo) <- runExceptT' $ A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- runExceptT' $ A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- runExceptT' $ A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True confId <- get alice >>= \case ("", _, A.CONF confId pqSup' _ "bob's connInfo") -> do @@ -329,7 +331,8 @@ agentViaProxyRetryOffline = do withServer $ \_ -> do (aliceId, bobId) <- withServer2 $ \_ -> runRight $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + (aliceId, sqSecured) <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe + liftIO $ sqSecured `shouldBe` True ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn allowConnection alice bobId confId "alice's connInfo" @@ -358,11 +361,15 @@ agentViaProxyRetryOffline = do -- proxy relay down 4 <- msgId <$> A.sendMessage bob aliceId pqEnc noMsgFlags msg2 bob `down` aliceId - withServer2 $ \_ -> runRight_ $ do - bob `up` aliceId - get bob ##> ("", aliceId, A.SENT (baseId + 4) bProxySrv) - get alice =##> \case ("", c, Msg' _ pq msg2') -> c == bobId && pq == pqEnc && msg2 == msg2'; _ -> False - ackMessage alice bobId (baseId + 4) Nothing + withServer2 $ \_ -> do + getInAnyOrder + bob + [ \case ("", "", AEvt SAENone (UP _ [c])) -> c == aliceId; _ -> False, + \case ("", c, AEvt SAEConn (A.SENT mId srv)) -> c == aliceId && mId == baseId + 4 && srv == bProxySrv; _ -> False + ] + runRight_ $ do + get alice =##> \case ("", c, Msg' _ pq msg2') -> c == bobId && pq == pqEnc && msg2 == msg2'; _ -> False + ackMessage alice bobId (baseId + 4) Nothing where withServer :: (ThreadId -> IO a) -> IO a withServer = withServer_ testStoreLogFile testStoreMsgsFile testPort diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 10516b9f2..60aa1dd1c 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -610,7 +610,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 2 logSize testStoreMsgsFile `shouldReturn` 5 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats1 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats1 [rId] 5 1 @@ -628,7 +628,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 -- the last message is not removed because it was not ACK'd logSize testStoreMsgsFile `shouldReturn` 3 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats2 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats2 [rId] 5 3 @@ -647,7 +647,7 @@ testRestoreMessages at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 logSize testStoreMsgsFile `shouldReturn` 0 - logSize testServerStatsBackupFile `shouldReturn` 55 + logSize testServerStatsBackupFile `shouldReturn` 71 Right stats3 <- strDecode <$> B.readFile testServerStatsBackupFile checkStats stats3 [rId] 5 5