From f22f55ec94d8b3c2285865028aeeac268403fa36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Domen=20Ko=C5=BEar?= Date: Sat, 23 Dec 2023 13:10:52 +0000 Subject: [PATCH] Implement ping-pong implementation that handles stale connections. Because the Server used only pong implementation it's not possible to keep backwards compatibility with the new implementation that also does pinging. --- CHANGELOG | 5 ++ README.md | 2 +- src/Network/WebSockets.hs | 4 ++ src/Network/WebSockets/Client.hs | 3 + src/Network/WebSockets/Connection.hs | 13 +++- src/Network/WebSockets/Connection/PingPong.hs | 62 +++++++++++++++++++ src/Network/WebSockets/Server.hs | 62 +++---------------- websockets.cabal | 2 + 8 files changed, 96 insertions(+), 57 deletions(-) create mode 100644 src/Network/WebSockets/Connection/PingPong.hs diff --git a/CHANGELOG b/CHANGELOG index 6c749bb..67ce9bd 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,10 @@ # CHANGELOG +- 0.13.0.0 (xxx) + * Introduce `Network.WebSockets.Connection.PingPong` to + handle ping pong for any Connection, be it Client or Server. + * Remove `serverRequirePong` option in favor of the new implementation. + - 0.12.7.3 (2021-10-26) * Bump `attoparsec` dependency upper bound to 0.15 diff --git a/README.md b/README.md index 0cd806a..2aeae0a 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ server and client in Haskell. ## Features - Provides Server/Client implementations of the websocket protocol -- Ping/Pong building blocks for stale connection checking +- `withPingPong` helper for stale connection checking - TLS support via [wuss](https://hackage.haskell.org/package/wuss) package ## Caveats diff --git a/src/Network/WebSockets.hs b/src/Network/WebSockets.hs index 1a12527..46c2e5e 100644 --- a/src/Network/WebSockets.hs +++ b/src/Network/WebSockets.hs @@ -83,6 +83,9 @@ module Network.WebSockets , newClientConnection -- * Utilities + , PingPongOptions(..) + , defaultPingPongOptions + , withPingPong , withPingThread , forkPingThread ) where @@ -91,6 +94,7 @@ module Network.WebSockets -------------------------------------------------------------------------------- import Network.WebSockets.Client import Network.WebSockets.Connection +import Network.WebSockets.Connection.PingPong import Network.WebSockets.Http import Network.WebSockets.Server import Network.WebSockets.Types diff --git a/src/Network/WebSockets/Client.hs b/src/Network/WebSockets/Client.hs index 28880c3..363ae7e 100644 --- a/src/Network/WebSockets/Client.hs +++ b/src/Network/WebSockets/Client.hs @@ -20,6 +20,7 @@ module Network.WebSockets.Client -------------------------------------------------------------------------------- import qualified Data.ByteString.Builder as Builder import Control.Exception (bracket, finally, throwIO) +import Control.Concurrent.MVar (newEmptyMVar) import Control.Monad (void) import Data.IORef (newIORef) import qualified Data.Text as T @@ -157,12 +158,14 @@ streamToClientConnection stream opts = do (connectionMessageDataSizeLimit opts) stream write <- encodeMessages protocol ClientConnection stream sentRef <- newIORef False + heartbeat <- newEmptyMVar return $ Connection { connectionOptions = opts , connectionType = ClientConnection , connectionProtocol = protocol , connectionParse = parse , connectionWrite = write + , connectionHeartbeat = heartbeat , connectionSentClose = sentRef } where diff --git a/src/Network/WebSockets/Connection.hs b/src/Network/WebSockets/Connection.hs index ec637b9..77ee60b 100644 --- a/src/Network/WebSockets/Connection.hs +++ b/src/Network/WebSockets/Connection.hs @@ -50,6 +50,7 @@ import Control.Applicative ((<$>)) import Control.Concurrent (forkIO, threadDelay) import qualified Control.Concurrent.Async as Async +import Control.Concurrent.MVar (MVar, newEmptyMVar, tryPutMVar) import Control.Exception (AsyncException, fromException, handle, @@ -179,13 +180,15 @@ acceptRequestWith pc ar = case find (flip compatible request) protocols of write <- foldM (\x ext -> extWrite ext x) writeRaw exts parse <- foldM (\x ext -> extParse ext x) parseRaw exts - sentRef <- newIORef False + sentRef <- newIORef False + heartbeat <- newEmptyMVar let connection = Connection { connectionOptions = options , connectionType = ServerConnection , connectionProtocol = protocol , connectionParse = parse , connectionWrite = write + , connectionHeartbeat = heartbeat , connectionSentClose = sentRef } @@ -252,6 +255,7 @@ data Connection = Connection { connectionOptions :: !ConnectionOptions , connectionType :: !ConnectionType , connectionProtocol :: !Protocol + , connectionHeartbeat :: !(MVar ()) , connectionParse :: !(IO (Maybe Message)) , connectionWrite :: !([Message] -> IO ()) , connectionSentClose :: !(IORef Bool) @@ -294,6 +298,7 @@ receiveDataMessage conn = do unless hasSentClose $ send conn msg throwIO $ CloseRequest i closeMsg Pong _ -> do + _ <- tryPutMVar (connectionHeartbeat conn) () connectionOnPong (connectionOptions conn) receiveDataMessage conn Ping pl -> do @@ -401,6 +406,9 @@ sendPong conn = send conn . ControlMessage . Pong . toLazyByteString -- This is useful to keep idle connections open through proxies and whatnot. -- Many (but not all) proxies have a 60 second default timeout, so based on that -- sending a ping every 30 seconds is a good idea. +-- +-- Note that usually you want to use 'Network.WebSockets.Connection.PingPong.withPingPong' +-- to timeout the connection if a pong is not received. withPingThread :: Connection -> Int -- ^ Second interval in which pings should be sent. @@ -410,7 +418,6 @@ withPingThread withPingThread conn n action app = Async.withAsync (pingThread conn n action) (\_ -> app) - -------------------------------------------------------------------------------- -- | DEPRECATED: Use 'withPingThread' instead. -- @@ -445,4 +452,4 @@ pingThread conn n action ignore e = case fromException e of Just async -> throwIO (async :: AsyncException) - Nothing -> return () + Nothing -> return () \ No newline at end of file diff --git a/src/Network/WebSockets/Connection/PingPong.hs b/src/Network/WebSockets/Connection/PingPong.hs new file mode 100644 index 0000000..cf3ae96 --- /dev/null +++ b/src/Network/WebSockets/Connection/PingPong.hs @@ -0,0 +1,62 @@ +module Network.WebSockets.Connection.PingPong + ( withPingPong + , PingPongOptions(..) + , PongTimeout(..) + , defaultPingPongOptions + ) where + +import Control.Concurrent.Async as Async +import Control.Exception +import Control.Monad (void) +import Network.WebSockets.Connection (Connection, connectionHeartbeat, pingThread) +import Control.Concurrent.MVar (takeMVar) +import System.Timeout (timeout) + + +-- | Exception type used to kill connections if there +-- is a pong timeout. +data PongTimeout = PongTimeout deriving Show + +instance Exception PongTimeout + + +-- | Options for ping-pong +-- +-- Make sure that the ping interval is less than the pong timeout, +-- for example N/2. +data PingPongOptions = PingPongOptions { + pingInterval :: Int, -- ^ Interval in seconds + pongTimeout :: Int, -- ^ Timeout in seconds + pingAction :: IO () -- ^ Action to perform after sending a ping +} + +-- | Default options for ping-pong +-- +-- Ping every 15 seconds, timeout after 30 seconds +defaultPingPongOptions :: PingPongOptions +defaultPingPongOptions = PingPongOptions { + pingInterval = 15, + pongTimeout = 30, + pingAction = return () +} + +-- | Run an application with ping-pong enabled. Raises PongTimeout if a pong is not received. +-- +-- Can used with Client and Server connections. +withPingPong :: PingPongOptions -> Connection -> (Connection -> IO ()) -> IO () +withPingPong options connection app = void $ + withAsync (app connection) $ \appAsync -> do + withAsync (pingThread connection (pingInterval options) (pingAction options)) $ \pingAsync -> do + withAsync (heartbeat >> throwIO PongTimeout) $ \heartbeatAsync -> do + waitAnyCancel [appAsync, pingAsync, heartbeatAsync] + where + heartbeat = whileJust $ timeout (pongTimeout options * 1000 * 1000) + $ takeMVar (connectionHeartbeat connection) + + -- Loop until action returns Nothing + whileJust :: IO (Maybe a) -> IO () + whileJust action = do + result <- action + case result of + Nothing -> return () + Just _ -> whileJust action \ No newline at end of file diff --git a/src/Network/WebSockets/Server.hs b/src/Network/WebSockets/Server.hs index 827ed37..184aace 100644 --- a/src/Network/WebSockets/Server.hs +++ b/src/Network/WebSockets/Server.hs @@ -19,19 +19,17 @@ module Network.WebSockets.Server -------------------------------------------------------------------------------- -import Control.Concurrent (takeMVar, tryPutMVar, - newEmptyMVar) import qualified Control.Concurrent.Async as Async -import Control.Exception (Exception, bracket, +import Control.Exception (bracket, bracketOnError, finally, mask_, throwIO) import Network.Socket (Socket) import qualified Network.Socket as S -import System.Timeout (timeout) -------------------------------------------------------------------------------- import Network.WebSockets.Connection +import Network.WebSockets.Connection.PingPong (PongTimeout(..)) import Network.WebSockets.Http import qualified Network.WebSockets.Stream as Stream import Network.WebSockets.Types @@ -83,10 +81,6 @@ data ServerOptions = ServerOptions { serverHost :: String , serverPort :: Int , serverConnectionOptions :: ConnectionOptions - -- | Require a pong from the client every N seconds; otherwise kill the - -- connection. If you use this, you should also use 'withPingThread' to - -- send a ping at a smaller interval; for example N/2. - , serverRequirePong :: Maybe Int } @@ -96,7 +90,6 @@ defaultServerOptions = ServerOptions { serverHost = "127.0.0.1" , serverPort = 8080 , serverConnectionOptions = defaultConnectionOptions - , serverRequirePong = Nothing } @@ -109,43 +102,16 @@ runServerWithOptions :: ServerOptions -> ServerApp -> IO a runServerWithOptions opts app = S.withSocketsDo $ bracket (makeListenSocket (serverHost opts) (serverPort opts)) - S.close $ \sock -> do - let connOpts = serverConnectionOptions opts - - connThread conn = case serverRequirePong opts of - Nothing -> runApp conn connOpts app - Just grace -> do - heartbeat <- newEmptyMVar - - let -- Update the connection options to perform a heartbeat - -- whenever a pong is received. - connOpts' = connOpts - { connectionOnPong = do - _ <- tryPutMVar heartbeat () - connectionOnPong connOpts - } - - whileJust io = do - result <- io - case result of - Nothing -> return () - Just _ -> whileJust io - - -- Runs until a pong was not received within the grace - -- period. - heart = whileJust $ timeout (grace * 1000000) (takeMVar heartbeat) - - Async.race_ - (runApp conn connOpts' app) - (heart >> throwIO PongTimeout) - + S.close + (\sock -> + let mainThread = do (conn, _) <- S.accept sock Async.withAsyncWithUnmask - (\unmask -> unmask (connThread conn) `finally` S.close conn) + (\unmask -> unmask (runApp conn (serverConnectionOptions opts) app) `finally` S.close conn) (\_ -> mainThread) - - mask_ mainThread + in mask_ mainThread + ) -------------------------------------------------------------------------------- @@ -205,14 +171,4 @@ makePendingConnectionFromStream stream opts = do , pendingRequest = request , pendingOnAccept = \_ -> return () , pendingStream = stream - } - - --------------------------------------------------------------------------------- --- | Internally used exception type used to kill connections if there --- is a pong timeout. -data PongTimeout = PongTimeout deriving Show - - --------------------------------------------------------------------------------- -instance Exception PongTimeout + } \ No newline at end of file diff --git a/websockets.cabal b/websockets.cabal index 804c46a..1917bf4 100644 --- a/websockets.cabal +++ b/websockets.cabal @@ -62,6 +62,7 @@ Library Network.WebSockets Network.WebSockets.Client Network.WebSockets.Connection + Network.WebSockets.Connection.PingPong Network.WebSockets.Extensions Network.WebSockets.Stream -- Network.WebSockets.Util.PubSub TODO @@ -108,6 +109,7 @@ Test-suite websockets-tests Network.WebSockets.Client Network.WebSockets.Connection Network.WebSockets.Connection.Options + Network.WebSockets.Connection.PingPong Network.WebSockets.Extensions Network.WebSockets.Extensions.Description Network.WebSockets.Extensions.PermessageDeflate