diff --git a/src/Control/Monad/Free.hs b/src/Control/Monad/Free.hs index 622c6c1..10fc974 100644 --- a/src/Control/Monad/Free.hs +++ b/src/Control/Monad/Free.hs @@ -6,6 +6,7 @@ {-# LANGUAGE Rank2Types #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE Safe #-} +{-# LANGUAGE TupleSections #-} ----------------------------------------------------------------------------- -- | @@ -33,6 +34,11 @@ module Control.Monad.Free , cutoff , unfold , unfoldM + , after + , before + , weave + , weaveMax + , weaveMin , _Pure, _Free ) where @@ -395,3 +401,40 @@ _Free = dimap unfree (either pure (fmap Free)) . right' unfree (Pure x) = Left (Pure x) {-# INLINE unfree #-} {-# INLINE _Free #-} + +before :: Functor m => m () -> Free m a -> Free m a +before mu = go + where + go = iterM $ \mfa -> liftF mu *> wrap (fmap go mfa) + +after :: Functor m => m () -> Free m a -> Free m a +after mu = go + where + go = iterM $ \mfa -> wrap $ flip fmap mfa $ \fa' -> liftF mu *> go fa' + +weave + :: forall f a b c + . Functor f + => ( a -> b -> Free f c) + -> ( a -> f (Free f b) -> Free f c) + -> (f (Free f a) -> b -> Free f c) + -> Free f a + -> Free f b + -> Free f c +weave end endA endB = go + where + go fa fb = case (fa, fb) of + (Free ma, Free mb) -> join $ liftA2 go (liftF ma) (liftF mb) + (Pure a, Pure b) -> end a b + (Pure a, Free mb) -> endA a mb + (Free ma, Pure b) -> endB ma b + +weaveMax :: Functor f => Free f a -> Free f b -> Free f (a,b) +weaveMax = weave + (curry Pure) + (\a fb -> fmap (a,) (Free fb)) + (\fa b -> fmap (,b) (Free fa)) + +weaveMin :: Functor f => Free f a -> Free f b -> Free f () +weaveMin = weave stop stop stop + where stop _ _ = pure ()