Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retryable transactions + async exception handling #1482

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions persistent-postgresql/ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog for persistent-postgresql

## 2.13.6.0

* [#1482](https://github.com/yesodweb/persistent/pull/1482)
* Add `isSerializationFailure` and `isDeadlockDetected` exception predicates

## 2.13.5.2

* [#1471](https://github.com/yesodweb/persistent/pull/1471)
Expand Down
41 changes: 36 additions & 5 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ module Database.Persist.Postgresql
, createPostgresqlPoolModified
, createPostgresqlPoolModifiedWithVersion
, createPostgresqlPoolWithConf

, isSerializationFailure
, isDeadlockDetected

, module Database.Persist.Sql
, ConnectionString
, HandleUpdateCollision
Expand Down Expand Up @@ -77,13 +81,14 @@ import qualified Database.PostgreSQL.Simple.Transaction as PG
import qualified Database.PostgreSQL.Simple.Types as PG

import Control.Arrow
import Control.Exception (Exception, throw, throwIO)
import Control.Exception
(Exception(fromException), SomeException, throw, throwIO)
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift (MonadIO(..), MonadUnliftIO)
import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT)
import Control.Monad.Trans.Reader (ReaderT(..), asks, runReaderT)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), asks, runReaderT)
#if !MIN_VERSION_base(4,12,0)
import Control.Monad.Trans.Reader (withReaderT)
#endif
Expand All @@ -102,8 +107,8 @@ import qualified Data.Conduit.List as CL
import Data.Data (Data)
import Data.Either (partitionEithers)
import Data.Function (on)
import Data.IORef
import Data.Int (Int64)
import Data.IORef
import Data.List (find, foldl', groupBy, sort)
import qualified Data.List as List
import Data.List.NonEmpty (NonEmpty)
Expand All @@ -122,12 +127,13 @@ import System.Environment (getEnvironment)
#if MIN_VERSION_base(4,12,0)
import Database.Persist.Compatible
#endif
import qualified Data.Vault.Strict as Vault
import Database.Persist.Postgresql.Internal
import Database.Persist.Sql
import qualified Database.Persist.Sql.Util as Util
import Database.Persist.SqlBackend
import Database.Persist.SqlBackend.StatementCache (StatementCache, mkSimpleStatementCache, mkStatementCache)
import qualified Data.Vault.Strict as Vault
import Database.Persist.SqlBackend.StatementCache
(StatementCache, mkSimpleStatementCache, mkStatementCache)
import System.IO.Unsafe (unsafePerformIO)

-- | A @libpq@ connection string. A simple example of connection
Expand Down Expand Up @@ -1953,6 +1959,31 @@ createRawPostgresqlPoolWithConf conf hooks = do
modConn = pgConfHooksAfterCreate hooks
createSqlPoolWithConfig (open' modConn getVer withRawConnection (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf)

-- | An exception predicate checking for a PostgreSQL serialization error, i.e.
-- a @SQLSTATE@ error code of @"40001"@ (@serialization_failure@).
--
-- This error can occur when concurrent transactions modify the same row(s) at
-- serializable isolation level.
--
-- This predicate is intended for use with 'runSqlPoolWithExtensibleHooksRetry'.
--
-- @since 2.13.6.0
isSerializationFailure :: SomeException -> Bool
isSerializationFailure ex
| Just sqlError <- fromException ex = PG.isSerializationError sqlError
| otherwise = False

-- | An exception predicate checking for a PostgreSQL deadlock detected error,
-- i.e. a @SQLSTATE@ error code of @"40P01"@ (@deadlock_detected@).
--
-- This predicate is intended for use with 'runSqlPoolWithExtensibleHooksRetry'.
--
-- @since 2.13.6.0
isDeadlockDetected :: SomeException -> Bool
isDeadlockDetected ex
| Just sqlError <- fromException ex = PG.sqlState sqlError == "40P01"
| otherwise = False

#if MIN_VERSION_base(4,12,0)
instance (PersistCore b) => PersistCore (RawPostgresql b) where
newtype BackendKey (RawPostgresql b) = RawPostgresqlKey { unRawPostgresqlKey :: BackendKey (Compatible b (RawPostgresql b)) }
Expand Down
5 changes: 4 additions & 1 deletion persistent-postgresql/persistent-postgresql.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: persistent-postgresql
version: 2.13.5.2
version: 2.13.6.0
license: MIT
license-file: LICENSE
author: Felipe Lessa, Michael Snoyman <[email protected]>
Expand Down Expand Up @@ -58,6 +58,8 @@ test-suite test
UpsertWhere
ImplicitUuidSpec
MigrationReferenceSpec
AsyncExceptionsTest
RetryableTransactionsTest
ghc-options: -Wall

build-depends: base >= 4.9 && < 5
Expand Down Expand Up @@ -86,6 +88,7 @@ test-suite test
, unliftio
, unordered-containers
, vector
, postgresql-simple
default-language: Haskell2010

executable conn-kill
Expand Down
196 changes: 196 additions & 0 deletions persistent-postgresql/test/AsyncExceptionsTest.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AsyncExceptionsTest
( specs
) where

import Control.Concurrent
( ThreadId
, forkIO
, killThread
, myThreadId
, newEmptyMVar
, putMVar
, takeMVar
)
import Control.Exception (MaskingState(MaskedUninterruptible), getMaskingState)
import Data.Function ((&))
import Database.Persist.SqlBackend.SqlPoolHooks (modifyRunOnException)
import GHC.Stack (SrcLoc, callStack, getCallStack)
import HookCounts
( HookCountRefs(..)
, HookCounts(..)
, hookCountsShouldBe
, newHookCountRefs
, trackHookCounts
)
import Init (aroundAll_)
import PgInit
( Filter
, HasCallStack
, MonadIO(..)
, PersistQueryWrite(deleteWhere)
, PersistStoreWrite(insert_)
, ReaderT
, RunConnArgs(sqlPoolHooks)
, Spec
, SqlBackend
, Text
, defaultRunConnArgs
, describe
, it
, mkMigrate
, mkPersist
, persistLowerCase
, runConnUsing
, runConn_
, runMigrationSilent
, share
, sqlSettings
, void
)
import Test.HUnit.Lang (FailureReason(Reason), HUnitFailure(HUnitFailure))
import UnliftIO.Exception (bracket_, throwTo)

share
[mkPersist sqlSettings, mkMigrate "asyncExceptionsTestMigrate"]
[persistLowerCase|
AsyncExceptionTestData
stuff Text
Primary stuff
deriving Eq Show
|]

setup :: IO ()
setup = runConn_ $ void $ runMigrationSilent asyncExceptionsTestMigrate

teardown :: IO ()
teardown = runConn_ cleanDB

cleanDB :: forall m. (MonadIO m) => ReaderT SqlBackend m ()
cleanDB = deleteWhere ([] :: [Filter AsyncExceptionTestData])

specs :: Spec
specs = aroundAll_ (bracket_ setup teardown) $ do
describe "Testing async exceptions" $ do
it "runOnException hook is executed" $ do
insertDoneRef <- newEmptyMVar
shouldProceedRef <- newEmptyMVar

hookCountRefs <- newHookCountRefs
runConnArgs <- mkRunConnArgs hookCountRefs

threadId <- forkIO $ do
runConnUsing runConnArgs $ do
insert_ $ AsyncExceptionTestData "bloorp"
liftIO $ do
-- "Child" thread signals to the main thread that the insert was
-- executed.
putMVar insertDoneRef ()
-- "Child" thread waits around indefinitely on this @MVar@.
-- @shouldProceedRef@ is intentionally never written to in this test
-- so that the "child" thread is blocked here until the main thread
-- kills it via async exception. See the remaining comments in this
-- test for more detail.
takeMVar shouldProceedRef

-- Main thread waits here for the signal from the "child" thread telling
-- us the DB insert has been performed. More specifically, we know the
-- following events have occurred in the "child" thread after this
-- @takeMVar@ call succeeds:
--
-- 1) The @alterBackend@ hook was executed
-- 2) The @runBefore@ hook was executed
-- 3) The insert of our test data was executed
-- 4) Execution is blocked right after the insert, so either of the
-- @runOnException@ or @runAfter@ hooks have not yet been executed.
takeMVar insertDoneRef

-- Verify that the actual hook execution in the "child" thread is as
-- described previously.
hookCountRefs `hookCountsShouldBe`
HookCounts
{ alterBackendCount = 1
, runBeforeCount = 1
, runOnExceptionCount = 0
, runAfterCount = 0
}

-- Main thread kills the "child" thread via async exception while the
-- "child" thread is still in its user-specified DB action, which should
-- cause the @runOnException@ hook to fire, rolling back the transaction.
--
-- Note that the @runOnException@ hook produced by @mkRunConnArgs@ also
-- ensures the handler's masking state is uninterruptible. See
-- @mkRunConnArgs@ for that check's implementation.
killThread threadId

-- Verify that the @runOnException@ hook was indeed executed.
hookCountRefs `hookCountsShouldBe`
HookCounts
{ alterBackendCount = 1
, runBeforeCount = 1
, runOnExceptionCount = 1
, runAfterCount = 0
}

-- | Build a 'RunConnArgs' value for use in this module's specs.
--
-- This function should only be called from the main thread.
mkRunConnArgs
:: forall m
. (MonadIO m)
=> HookCountRefs
-> m (RunConnArgs m)
mkRunConnArgs hookCountRefs = do
threadId <- liftIO myThreadId
pure $ (defaultRunConnArgs @m)
{ sqlPoolHooks =
trackHookCounts hookCountRefs (sqlPoolHooks defaultRunConnArgs)
& flip modifyRunOnException (\origRunOnException conn level ex -> do
-- It's sneaky to make this masking state assertion here rather
-- than explicitly in a spec. At this time, it feels a bit cleaner
-- to keep this assertion tucked away in here. The downside is
-- that this function does not run in the main thread, so we must
-- throw an expectation failure into the main thread on assertion
-- failure to have it reported by Hspec.
liftIO $
getMaskingState >>= \case
MaskedUninterruptible -> pure ()
_ ->
throwExpectationFailureTo
threadId
"Expected runOnException masking to be uninterruptible"

origRunOnException conn level ex
)
}

throwExpectationFailureTo
:: HasCallStack
=> ThreadId
-> String
-> IO ()
throwExpectationFailureTo threadId msg =
throwTo threadId $ HUnitFailure location $ Reason msg

location :: HasCallStack => Maybe SrcLoc
location = case reverse $ getCallStack callStack of
(_, loc) : _ -> Just loc
[] -> Nothing
Loading