Skip to content

Commit

Permalink
Kill worker on PeerClosed
Browse files Browse the repository at this point in the history
  • Loading branch information
bgamari committed Aug 24, 2023
1 parent bf60ed2 commit b58f7b5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 7 deletions.
3 changes: 3 additions & 0 deletions warp/Network/Wai/Handler/Warp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ module Network.Wai.Handler.Warp (
, openFreePort
-- * Version
, warpVersion
-- * Handling premature connection closure
, connectionIsInactive
-- * HTTP/2
-- ** HTTP2 data
, HTTP2Data
Expand Down Expand Up @@ -141,6 +143,7 @@ import Network.Wai (Request, Response, vault)
import System.TimeManager

import Network.Wai.Handler.Warp.FileInfoCache
import Network.Wai.Handler.Warp.HTTP1 (connectionIsInactive)
import Network.Wai.Handler.Warp.HTTP2.Request (getHTTP2Data, setHTTP2Data, modifyHTTP2Data)
import Network.Wai.Handler.Warp.HTTP2.Types
import Network.Wai.Handler.Warp.Imports
Expand Down
34 changes: 28 additions & 6 deletions warp/Network/Wai/Handler/Warp/HTTP1.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Wai.Handler.Warp.HTTP1 (
http1
http1,
connectionIsInactive
) where

import "iproute" Data.IP (toHostAddress, toHostAddress6)
import qualified Control.Concurrent as Conc (yield)
import qualified Control.Concurrent as Conc
import qualified UnliftIO
import UnliftIO (SomeException, fromException, throwIO)
import qualified Data.ByteString as BS
import Data.Char (chr)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.Vault.Lazy as Vault
import Network.Socket (SockAddr(SockAddrInet, SockAddrInet6))
import Network.Wai
import Network.Wai.Internal (ResponseReceived (ResponseReceived))
Expand All @@ -30,11 +32,21 @@ import Network.Wai.Handler.Warp.Types

http1 :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> ByteString -> IO ()
http1 settings ii conn transport app origAddr th bs0 = do
connActive <- mkConnActiveFlag
case connRegisterPeerClosedCb conn of
-- TODO Ignore only operation-not-supported exceptions
Just registerCb -> void $ UnliftIO.tryIO $ do
tid <- Conc.myThreadId
registerCb $ do
waitUntilConnInactive connActive
UnliftIO.throwTo tid PeerClosedException
Nothing -> return ()

istatus <- newIORef True
src <- mkSource (wrappedRecv conn istatus (settingsSlowlorisSize settings))
leftoverSource src bs0
addr <- getProxyProtocolAddr src
http1server settings ii conn transport app addr th istatus src
http1server settings ii conn transport connActive app addr th istatus src
where
wrappedRecv Connection { connRecv = recv } istatus slowlorisSize = do
bs <- recv
Expand Down Expand Up @@ -83,8 +95,8 @@ http1 settings ii conn transport app origAddr th bs0 = do

decodeAscii = map (chr . fromEnum) . BS.unpack

http1server :: Settings -> InternalInfo -> Connection -> Transport -> Application -> SockAddr -> T.Handle -> IORef Bool -> Source -> IO ()
http1server settings ii conn transport app addr th istatus src =
http1server :: Settings -> InternalInfo -> Connection -> Transport -> ConnActiveFlag -> Application -> SockAddr -> T.Handle -> IORef Bool -> Source -> IO ()
http1server settings ii conn transport connActive app addr th istatus src =
loop True `UnliftIO.catchAny` handler
where
handler e
Expand All @@ -98,7 +110,8 @@ http1server settings ii conn transport app addr th istatus src =
throwIO e

loop firstRequest = do
(req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport
setConnActiveFlag connActive True
(req, mremainingRef, idxhdr, nextBodyFlush) <- recvRequest firstRequest settings conn ii th addr src transport connActive
keepAlive <- processRequest settings ii conn app th istatus src req mremainingRef idxhdr nextBodyFlush
`UnliftIO.catchAny` \e -> do
settingsOnException settings (Just req) e
Expand Down Expand Up @@ -219,3 +232,12 @@ flushBody src = loop
| BS.null bs -> return True
| toRead' >= 0 -> loop toRead'
| otherwise -> return False

-- | Used by a handler to indicate that its current computation can be safely
-- killed if the requesting connection is shutdown.
connectionIsInactive :: Request -> IO ()
connectionIsInactive req = do
case Vault.lookup connActiveFlagKey (vault req) of
Just flag -> setConnActiveFlag flag False
Nothing -> return ()

9 changes: 8 additions & 1 deletion warp/Network/Wai/Handler/Warp/Request.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module Network.Wai.Handler.Warp.Request (
#ifdef MIN_VERSION_crypton_x509
, getClientCertificateKey
#endif
, connActiveFlagKey
, NoKeepAliveRequest (..)
) where

Expand Down Expand Up @@ -56,6 +57,7 @@ recvRequest :: Bool -- ^ first request on this connection?
-> SockAddr -- ^ Peer's address.
-> Source -- ^ Where HTTP request comes from.
-> Transport
-> ConnActiveFlag
-> IO (Request
,Maybe (I.IORef Int)
,IndexedHeader
Expand All @@ -65,7 +67,7 @@ recvRequest :: Bool -- ^ first request on this connection?
-- 'IndexedHeader' of HTTP request for internal use,
-- Body producing action used for flushing the request body

recvRequest firstRequest settings conn ii th addr src transport = do
recvRequest firstRequest settings conn ii th addr src transport connActive = do
hdrlines <- headerLines (settingsMaxTotalHeaderLength settings) firstRequest src
(method, unparsedPath, path, query, httpversion, hdr) <- parseHeaderLines hdrlines
let idxhdr = indexRequestHeader hdr
Expand All @@ -76,6 +78,7 @@ recvRequest firstRequest settings conn ii th addr src transport = do
rawPath = if settingsNoParsePath settings then unparsedPath else path
vaultValue = Vault.insert pauseTimeoutKey (Timeout.pause th)
$ Vault.insert getFileInfoKey (getFileInfo ii)
$ Vault.insert connActiveFlagKey connActive
#ifdef MIN_VERSION_crypton_x509
$ Vault.insert getClientCertificateKey (getTransportClientCertificate transport)
#endif
Expand Down Expand Up @@ -328,3 +331,7 @@ getClientCertificateKey :: Vault.Key (Maybe CertificateChain)
getClientCertificateKey = unsafePerformIO Vault.newKey
{-# NOINLINE getClientCertificateKey #-}
#endif

connActiveFlagKey :: Vault.Key ConnActiveFlag
connActiveFlagKey = unsafePerformIO Vault.newKey
{-# NOINLINE connActiveFlagKey #-}
16 changes: 16 additions & 0 deletions warp/Network/Wai/Handler/Warp/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
module Network.Wai.Handler.Warp.Run where

import Control.Arrow (first)
import Control.Concurrent
import qualified Control.Exception
import Control.Exception (allowInterrupt)
import qualified Data.ByteString as S
Expand All @@ -21,6 +22,12 @@ import Network.Socket (gracefulClose)
#endif
import Network.Socket.BufferPool
import qualified Network.Socket.ByteString as Sock
#if MIN_VERSION_base(4,18,0)
-- For evtPeerClosed
import Network.Socket (withFdSocket)
import GHC.Event
import System.Posix.Types (Fd(Fd))
#endif
import Network.Wai
import System.Environment (lookupEnv)
import System.IO.Error (ioeGetErrorType)
Expand Down Expand Up @@ -59,6 +66,14 @@ socketConnection _ s = do
bufferPool <- newBufferPool 2048 16384
writeBuffer <- createWriteBuffer 16384
writeBufferRef <- newIORef writeBuffer
#if MIN_VERSION_base(4,18,0)
let registerPeerClosedCb = Just $ \cb -> withFdSocket s $ \fd -> do
Just mgr <- getSystemEventManager
_ <- registerFd mgr (\ _ _ -> cb) (Fd fd) evtPeerClosed OneShot
return ()
#else
let registerPeerClosedCb = Nothing
#endif
isH2 <- newIORef False -- HTTP/1.x
return Connection {
connSendMany = Sock.sendMany s
Expand All @@ -80,6 +95,7 @@ socketConnection _ s = do
, connRecvBuf = \_ _ -> return True -- obsoleted
, connWriteBuffer = writeBufferRef
, connHTTP2 = isH2
, connRegisterPeerClosedCb = registerPeerClosedCb
}
where
receive' sock pool = UnliftIO.handleIO handler $ receive sock pool
Expand Down
32 changes: 32 additions & 0 deletions warp/Network/Wai/Handler/Warp/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ instance UnliftIO.Exception ExceptionInsideResponseBody

----------------------------------------------------------------

-- | Exception thrown when the iniating client of a connection being handled by
-- a worker closes its end of the connection.
data PeerClosedException = PeerClosedException
deriving (Show)

instance UnliftIO.Exception PeerClosedException

----------------------------------------------------------------

-- | Data type to abstract file identifiers.
-- On Unix, a file descriptor would be specified to make use of
-- the file descriptor cache.
Expand Down Expand Up @@ -125,6 +134,7 @@ data Connection = Connection {
, connWriteBuffer :: IORef WriteBuffer
-- | Is this connection HTTP/2?
, connHTTP2 :: IORef Bool
, connRegisterPeerClosedCb :: Maybe (IO () -> IO ())
}

getConnHTTP2 :: Connection -> IO Bool
Expand All @@ -144,6 +154,28 @@ data InternalInfo = InternalInfo {

----------------------------------------------------------------

-- | In some HTTP/1 applications (e.g. those where requests are pure queries
-- which imply no "effects") it can make sense to abort running handlers when
-- the write-side of the client's connection closed (via @shutdown(2)@) before
-- a response has been sent. To facilitate this use-case, each handler thread
-- carries a 'ConnActiveFlag' which dictates whether the handler's current
-- computation can be safely aborted if the connection is shutdown.
newtype ConnActiveFlag = ConnActiveFlag (UnliftIO.TVar Bool)

mkConnActiveFlag :: IO ConnActiveFlag
mkConnActiveFlag = ConnActiveFlag <$> UnliftIO.newTVarIO True

setConnActiveFlag :: ConnActiveFlag -> Bool -> IO ()
setConnActiveFlag (ConnActiveFlag v) active = UnliftIO.atomically $
UnliftIO.writeTVar v active

waitUntilConnInactive :: ConnActiveFlag -> IO ()
waitUntilConnInactive (ConnActiveFlag v) = UnliftIO.atomically $ do
active <- UnliftIO.readTVar v
when active UnliftIO.retrySTM

----------------------------------------------------------------

-- | Type for input streaming.
data Source = Source !(IORef ByteString) !(IO ByteString)

Expand Down

0 comments on commit b58f7b5

Please sign in to comment.