Skip to content

Commit

Permalink
Implement ping-pong implementation that handles stale connections.
Browse files Browse the repository at this point in the history
Because the Server used only pong implementation it's not possible
to keep backwards compatibility with the new implementation that
also does pinging.
  • Loading branch information
domenkozar committed Dec 23, 2023
1 parent e0ac03d commit f22f55e
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 57 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

- 0.13.0.0 (xxx)
* Introduce `Network.WebSockets.Connection.PingPong` to
handle ping pong for any Connection, be it Client or Server.
* Remove `serverRequirePong` option in favor of the new implementation.

- 0.12.7.3 (2021-10-26)
* Bump `attoparsec` dependency upper bound to 0.15

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ server and client in Haskell.
## Features

- Provides Server/Client implementations of the websocket protocol
- Ping/Pong building blocks for stale connection checking
- `withPingPong` helper for stale connection checking
- TLS support via [wuss](https://hackage.haskell.org/package/wuss) package

## Caveats
Expand Down
4 changes: 4 additions & 0 deletions src/Network/WebSockets.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ module Network.WebSockets
, newClientConnection

-- * Utilities
, PingPongOptions(..)
, defaultPingPongOptions
, withPingPong
, withPingThread
, forkPingThread
) where
Expand All @@ -91,6 +94,7 @@ module Network.WebSockets
--------------------------------------------------------------------------------
import Network.WebSockets.Client
import Network.WebSockets.Connection
import Network.WebSockets.Connection.PingPong
import Network.WebSockets.Http
import Network.WebSockets.Server
import Network.WebSockets.Types
3 changes: 3 additions & 0 deletions src/Network/WebSockets/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module Network.WebSockets.Client
--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder as Builder
import Control.Exception (bracket, finally, throwIO)
import Control.Concurrent.MVar (newEmptyMVar)
import Control.Monad (void)
import Data.IORef (newIORef)
import qualified Data.Text as T
Expand Down Expand Up @@ -157,12 +158,14 @@ streamToClientConnection stream opts = do
(connectionMessageDataSizeLimit opts) stream
write <- encodeMessages protocol ClientConnection stream
sentRef <- newIORef False
heartbeat <- newEmptyMVar
return $ Connection
{ connectionOptions = opts
, connectionType = ClientConnection
, connectionProtocol = protocol
, connectionParse = parse
, connectionWrite = write
, connectionHeartbeat = heartbeat
, connectionSentClose = sentRef
}
where
Expand Down
13 changes: 10 additions & 3 deletions src/Network/WebSockets/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import Control.Applicative ((<$>))
import Control.Concurrent (forkIO,
threadDelay)
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.MVar (MVar, newEmptyMVar, tryPutMVar)
import Control.Exception (AsyncException,
fromException,
handle,
Expand Down Expand Up @@ -179,13 +180,15 @@ acceptRequestWith pc ar = case find (flip compatible request) protocols of
write <- foldM (\x ext -> extWrite ext x) writeRaw exts
parse <- foldM (\x ext -> extParse ext x) parseRaw exts

sentRef <- newIORef False
sentRef <- newIORef False
heartbeat <- newEmptyMVar
let connection = Connection
{ connectionOptions = options
, connectionType = ServerConnection
, connectionProtocol = protocol
, connectionParse = parse
, connectionWrite = write
, connectionHeartbeat = heartbeat
, connectionSentClose = sentRef
}

Expand Down Expand Up @@ -252,6 +255,7 @@ data Connection = Connection
{ connectionOptions :: !ConnectionOptions
, connectionType :: !ConnectionType
, connectionProtocol :: !Protocol
, connectionHeartbeat :: !(MVar ())
, connectionParse :: !(IO (Maybe Message))
, connectionWrite :: !([Message] -> IO ())
, connectionSentClose :: !(IORef Bool)
Expand Down Expand Up @@ -294,6 +298,7 @@ receiveDataMessage conn = do
unless hasSentClose $ send conn msg
throwIO $ CloseRequest i closeMsg
Pong _ -> do
_ <- tryPutMVar (connectionHeartbeat conn) ()
connectionOnPong (connectionOptions conn)
receiveDataMessage conn
Ping pl -> do
Expand Down Expand Up @@ -401,6 +406,9 @@ sendPong conn = send conn . ControlMessage . Pong . toLazyByteString
-- This is useful to keep idle connections open through proxies and whatnot.
-- Many (but not all) proxies have a 60 second default timeout, so based on that
-- sending a ping every 30 seconds is a good idea.
--
-- Note that usually you want to use 'Network.WebSockets.Connection.PingPong.withPingPong'
-- to timeout the connection if a pong is not received.
withPingThread
:: Connection
-> Int -- ^ Second interval in which pings should be sent.
Expand All @@ -410,7 +418,6 @@ withPingThread
withPingThread conn n action app =
Async.withAsync (pingThread conn n action) (\_ -> app)


--------------------------------------------------------------------------------
-- | DEPRECATED: Use 'withPingThread' instead.
--
Expand Down Expand Up @@ -445,4 +452,4 @@ pingThread conn n action

ignore e = case fromException e of
Just async -> throwIO (async :: AsyncException)
Nothing -> return ()
Nothing -> return ()
62 changes: 62 additions & 0 deletions src/Network/WebSockets/Connection/PingPong.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module Network.WebSockets.Connection.PingPong
( withPingPong
, PingPongOptions(..)
, PongTimeout(..)
, defaultPingPongOptions
) where

import Control.Concurrent.Async as Async
import Control.Exception
import Control.Monad (void)
import Network.WebSockets.Connection (Connection, connectionHeartbeat, pingThread)
import Control.Concurrent.MVar (takeMVar)
import System.Timeout (timeout)


-- | Exception type used to kill connections if there
-- is a pong timeout.
data PongTimeout = PongTimeout deriving Show

instance Exception PongTimeout


-- | Options for ping-pong
--
-- Make sure that the ping interval is less than the pong timeout,
-- for example N/2.
data PingPongOptions = PingPongOptions {
pingInterval :: Int, -- ^ Interval in seconds
pongTimeout :: Int, -- ^ Timeout in seconds
pingAction :: IO () -- ^ Action to perform after sending a ping
}

-- | Default options for ping-pong
--
-- Ping every 15 seconds, timeout after 30 seconds
defaultPingPongOptions :: PingPongOptions
defaultPingPongOptions = PingPongOptions {
pingInterval = 15,
pongTimeout = 30,
pingAction = return ()
}

-- | Run an application with ping-pong enabled. Raises PongTimeout if a pong is not received.
--
-- Can used with Client and Server connections.
withPingPong :: PingPongOptions -> Connection -> (Connection -> IO ()) -> IO ()
withPingPong options connection app = void $
withAsync (app connection) $ \appAsync -> do
withAsync (pingThread connection (pingInterval options) (pingAction options)) $ \pingAsync -> do
withAsync (heartbeat >> throwIO PongTimeout) $ \heartbeatAsync -> do
waitAnyCancel [appAsync, pingAsync, heartbeatAsync]
where
heartbeat = whileJust $ timeout (pongTimeout options * 1000 * 1000)
$ takeMVar (connectionHeartbeat connection)

-- Loop until action returns Nothing
whileJust :: IO (Maybe a) -> IO ()
whileJust action = do
result <- action
case result of
Nothing -> return ()
Just _ -> whileJust action
62 changes: 9 additions & 53 deletions src/Network/WebSockets/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,17 @@ module Network.WebSockets.Server


--------------------------------------------------------------------------------
import Control.Concurrent (takeMVar, tryPutMVar,
newEmptyMVar)
import qualified Control.Concurrent.Async as Async
import Control.Exception (Exception, bracket,
import Control.Exception (bracket,
bracketOnError, finally, mask_,
throwIO)
import Network.Socket (Socket)
import qualified Network.Socket as S
import System.Timeout (timeout)


--------------------------------------------------------------------------------
import Network.WebSockets.Connection
import Network.WebSockets.Connection.PingPong (PongTimeout(..))
import Network.WebSockets.Http
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Types
Expand Down Expand Up @@ -83,10 +81,6 @@ data ServerOptions = ServerOptions
{ serverHost :: String
, serverPort :: Int
, serverConnectionOptions :: ConnectionOptions
-- | Require a pong from the client every N seconds; otherwise kill the
-- connection. If you use this, you should also use 'withPingThread' to
-- send a ping at a smaller interval; for example N/2.
, serverRequirePong :: Maybe Int
}


Expand All @@ -96,7 +90,6 @@ defaultServerOptions = ServerOptions
{ serverHost = "127.0.0.1"
, serverPort = 8080
, serverConnectionOptions = defaultConnectionOptions
, serverRequirePong = Nothing
}


Expand All @@ -109,43 +102,16 @@ runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
bracket
(makeListenSocket (serverHost opts) (serverPort opts))
S.close $ \sock -> do
let connOpts = serverConnectionOptions opts

connThread conn = case serverRequirePong opts of
Nothing -> runApp conn connOpts app
Just grace -> do
heartbeat <- newEmptyMVar

let -- Update the connection options to perform a heartbeat
-- whenever a pong is received.
connOpts' = connOpts
{ connectionOnPong = do
_ <- tryPutMVar heartbeat ()
connectionOnPong connOpts
}

whileJust io = do
result <- io
case result of
Nothing -> return ()
Just _ -> whileJust io

-- Runs until a pong was not received within the grace
-- period.
heart = whileJust $ timeout (grace * 1000000) (takeMVar heartbeat)

Async.race_
(runApp conn connOpts' app)
(heart >> throwIO PongTimeout)

S.close
(\sock ->
let
mainThread = do
(conn, _) <- S.accept sock
Async.withAsyncWithUnmask
(\unmask -> unmask (connThread conn) `finally` S.close conn)
(\unmask -> unmask (runApp conn (serverConnectionOptions opts) app) `finally` S.close conn)
(\_ -> mainThread)

mask_ mainThread
in mask_ mainThread
)


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -205,14 +171,4 @@ makePendingConnectionFromStream stream opts = do
, pendingRequest = request
, pendingOnAccept = \_ -> return ()
, pendingStream = stream
}


--------------------------------------------------------------------------------
-- | Internally used exception type used to kill connections if there
-- is a pong timeout.
data PongTimeout = PongTimeout deriving Show


--------------------------------------------------------------------------------
instance Exception PongTimeout
}
2 changes: 2 additions & 0 deletions websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Library
Network.WebSockets
Network.WebSockets.Client
Network.WebSockets.Connection
Network.WebSockets.Connection.PingPong
Network.WebSockets.Extensions
Network.WebSockets.Stream
-- Network.WebSockets.Util.PubSub TODO
Expand Down Expand Up @@ -108,6 +109,7 @@ Test-suite websockets-tests
Network.WebSockets.Client
Network.WebSockets.Connection
Network.WebSockets.Connection.Options
Network.WebSockets.Connection.PingPong
Network.WebSockets.Extensions
Network.WebSockets.Extensions.Description
Network.WebSockets.Extensions.PermessageDeflate
Expand Down

0 comments on commit f22f55e

Please sign in to comment.