diff --git a/semialign/src/Data/Semialign.hs b/semialign/src/Data/Semialign.hs index a70f83c..f784f18 100644 --- a/semialign/src/Data/Semialign.hs +++ b/semialign/src/Data/Semialign.hs @@ -16,6 +16,14 @@ module Data.Semialign ( lpadZip, lpadZipWith, rpadZip, rpadZipWith, alignVectorWith, + -- * Unzip definition helpers + UnzipStrictSpineStrictPairs (..), + UnzipStrictSpineLazyPairs (..), + UnzipLazySpineLazyPairs (..), + unzipWithStrictSpineStrictPairs, + unzipWithStrictSpineLazyPairs, + unzipStrictSpineLazyPairs, + unzipWithLazySpineLazyPairs, ) where import Data.Semialign.Internal diff --git a/semialign/src/Data/Semialign/Internal.hs b/semialign/src/Data/Semialign/Internal.hs index afc0259..f31c8cb 100644 --- a/semialign/src/Data/Semialign/Internal.hs +++ b/semialign/src/Data/Semialign/Internal.hs @@ -1,5 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -14,7 +16,7 @@ module Data.Semialign.Internal where import Prelude (Bool (..), Either (..), Eq (..), Functor (fmap), Int, Maybe (..), Monad (..), Ord (..), Ordering (..), String, error, flip, fst, id, - maybe, snd, uncurry, ($), (++), (.)) + maybe, snd, uncurry, ($), (++), (.), Traversable, Foldable) import qualified Prelude as Prelude @@ -590,6 +592,95 @@ instance Biapplicative SBPair where biliftA2 f g (SBPair (a, b)) (SBPair (c, d)) = SBPair (f a c, g b d) +-- A copy of (,) with a lazier biliftA2 +newtype LBPair a b = LBPair { unLBPair :: (a, b) } + +instance Bifunctor LBPair where + bimap f g (LBPair ab) = LBPair (f a, g b) + where + -- Is this enough? I'm not sure. The danger is if + -- the call inlines and `ab = (p, q)` inlines, but for whatever + -- reason we end up with something like + -- + -- a = p + -- b = case ab of (_, q) -> q + -- + -- I've seen something vaguely like that before, in Data.List.transpose, + -- but I don't remember the details. If necessary, we can use + -- `GHC.Exts.noinline` on `ab` for `base >= 4.15`, or some kind of shim + -- elsewhere, but then we'll also want a rewrite rule + -- + -- bimap f g (LBPair a b) = LBPair (f a, g b) + -- + -- for when we get lucky. + {-# NOINLINE a #-} + {-# NOINLINE b #-} + (a, b) = ab + +instance Biapplicative LBPair where + bipure a b = LBPair (a, b) + biliftA2 f g (LBPair ab) (LBPair cd) = + LBPair (f a c, g b d) + where + {-# NOINLINE a #-} + {-# NOINLINE b #-} + {-# NOINLINE c #-} + {-# NOINLINE d #-} + (a, b) = ab + (c, d) = cd + +newtype UnzipStrictSpineStrictPairs t a = + UnzipStrictSpineStrictPairs { getUnzipStrictSpineStrictPairs :: t a } + deriving (Functor, Foldable, Traversable, Semialign, Align, Zip) + +instance (Zip t, Traversable t) => Unzip (UnzipStrictSpineStrictPairs t) where + unzipWith = unzipWithStrictSpineStrictPairs + +newtype UnzipStrictSpineLazyPairs t a = + UnzipStrictSpineLazyPairs { getUnzipStrictSpineLazyPairs :: t a } + deriving (Functor, Foldable, Traversable, Semialign, Align, Zip) + +instance (Zip t, Traversable t) => Unzip (UnzipStrictSpineLazyPairs t) where + unzipWith = unzipWithStrictSpineLazyPairs + unzip = unzipStrictSpineLazyPairs + +newtype UnzipLazySpineLazyPairs t a = + UnzipLazySpineLazyPairs { getUnzipLazySpineLazyPairs :: t a } + deriving (Functor, Foldable, Traversable, Semialign, Align, Zip) + +instance (Zip t, Traversable t) => Unzip (UnzipLazySpineLazyPairs t) where + unzipWith = unzipWithLazySpineLazyPairs + +unzipWithStrictSpineStrictPairs :: Traversable t + => (c -> (a, b)) -> t c -> (t a, t b) +unzipWithStrictSpineStrictPairs f = unSBPair . traverseBia (SBPair . f) + +unzipWithStrictSpineLazyPairs :: Traversable t + => (c -> (a, b)) -> t c -> (t a, t b) +unzipWithStrictSpineLazyPairs f = unSBPair . traverseBia (SBPair . foo) + where + foo c = let + {-# NOINLINE fc #-} + {-# NOINLINE a #-} + {-# NOINLINE b #-} + fc = f c + (a, b) = fc + in (a, b) + +unzipStrictSpineLazyPairs :: Traversable t + => t (a, b) -> (t a, t b) +unzipStrictSpineLazyPairs = unSBPair . traverseBia (SBPair . foo) + where + foo ab = let + {-# NOINLINE a #-} + {-# NOINLINE b #-} + (a, b) = ab + in (a, b) + +unzipWithLazySpineLazyPairs :: Traversable t + => (c -> (a, b)) -> t c -> (t a, t b) +unzipWithLazySpineLazyPairs f = unLBPair . traverseBia (LBPair . f) + instance Ord k => Unzip (Map k) where unzip = unzipDefault instance Ord k => Zip (Map k) where