diff --git a/free.cabal b/free.cabal index 765316b..4b8716f 100644 --- a/free.cabal +++ b/free.cabal @@ -89,6 +89,7 @@ library transformers-base >= 0.4 && < 0.5, template-haskell >= 2.7.0.0 && < 3, exceptions >= 0.6 && < 0.11, + monad-control >= 1 && < 1.1, containers < 0.7 if !impl(ghc >= 8.2) diff --git a/src/Control/Monad/Trans/Free.hs b/src/Control/Monad/Trans/Free.hs index 555cc91..b871a6f 100644 --- a/src/Control/Monad/Trans/Free.hs +++ b/src/Control/Monad/Trans/Free.hs @@ -4,6 +4,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE TypeFamilies #-} #if __GLASGOW_HASKELL__ >= 707 {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveGeneric #-} @@ -57,6 +58,8 @@ import Control.Monad (liftM, MonadPlus(..), ap, join) import Control.Monad.Base (MonadBase(..)) import Control.Monad.Catch (MonadThrow(..), MonadCatch(..)) import Control.Monad.Trans.Class +import Control.Monad.Trans.Control (MonadTransControl(..), MonadBaseControl(..), + ComposeSt, defaultLiftBaseWith, defaultRestoreM) import Control.Monad.Free.Class import qualified Control.Monad.Fail as Fail import Control.Monad.IO.Class @@ -325,6 +328,39 @@ instance (Functor f, MonadBase b m) => MonadBase b (FreeT f m) where liftBase = lift . liftBase {-# INLINE liftBase #-} +{- +This instance must satisfy: +* liftWith . const . return = return +liftWith . const . return $ x + = lift $ (const $ return x) joinFreeT + = lift (return x) + = return x + +* liftWith (const (m >>= f)) = liftWith (const m) >>= liftWith . const . f +liftWith (const m) >>= liftWith . const . f + = lift (const m (joinFreeT)) >>= \x -> lift $ const (f x) joinFreeT + = lift m >>= lift . f + = lift (m >>= f) + = lift (const (m >>= f) joinFreeT) + = liftWith (const (m >>= f)) +* liftWith (\run -> run t) >>= restoreT . return = t +liftWith (\run -> run t) >>= restoreT . return + = lift (joinFreeT t) >>= lift . return >>= hoistFreeT (return . runIdentity) + = lift (joinFreeT t) >>= hoistFreeT (return . runIdentity) + = t +-} +instance (Traversable f) => MonadTransControl (FreeT f) where + type StT (FreeT f) a = Free f a + liftWith mkFreeT = lift $ mkFreeT joinFreeT + {-# INLINE liftWith #-} + restoreT mstt = lift mstt >>= hoistFreeT (return . runIdentity) + {-# INLINE restoreT #-} + +instance (Traversable f, MonadBaseControl b m) => MonadBaseControl b (FreeT f m) where + type StM (FreeT f m) a = ComposeSt (FreeT f) m a + liftBaseWith = defaultLiftBaseWith + restoreM = defaultRestoreM + instance (Functor f, MonadReader r m) => MonadReader r (FreeT f m) where ask = lift ask {-# INLINE ask #-} diff --git a/src/Control/Monad/Trans/Free/Church.hs b/src/Control/Monad/Trans/Free/Church.hs index d613300..d4a78f7 100644 --- a/src/Control/Monad/Trans/Free/Church.hs +++ b/src/Control/Monad/Trans/Free/Church.hs @@ -3,6 +3,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeFamilies #-} #include "free-common.h" ----------------------------------------------------------------------------- @@ -48,9 +49,12 @@ module Control.Monad.Trans.Free.Church import Control.Applicative import Control.Category ((<<<), (>>>)) import Control.Monad +import Control.Monad.Base (MonadBase(..)) import Control.Monad.Catch (MonadCatch(..), MonadThrow(..)) import Control.Monad.Identity import Control.Monad.Trans.Class +import Control.Monad.Trans.Control (MonadTransControl(..), MonadBaseControl(..), + ComposeSt, defaultLiftBaseWith, defaultRestoreM) import Control.Monad.IO.Class import Control.Monad.Reader.Class import Control.Monad.Writer.Class @@ -156,6 +160,22 @@ instance (MonadIO m) => MonadIO (FT f m) where liftIO = lift . liftIO {-# INLINE liftIO #-} +instance MonadBase b m => MonadBase b (FT f m) where + liftBase = lift . liftBase + {-# INLINE liftBase #-} + +instance (Traversable f) => MonadTransControl (FT f) where + type StT (FT f) a = F f a + liftWith mkFT = lift $ mkFT joinFT + {-# INLINE liftWith #-} + restoreT mstt = lift mstt >>= hoistFT (return . runIdentity) + {-# INLINE restoreT #-} + +instance (MonadBaseControl b m, Traversable f) => MonadBaseControl b (FT f m) where + type StM (FT f m) a = ComposeSt (FT f) m a + liftBaseWith = defaultLiftBaseWith + restoreM = defaultRestoreM + instance (Functor f, MonadError e m) => MonadError e (FT f m) where throwError = lift . throwError {-# INLINE throwError #-}