Skip to content

Commit

Permalink
Fix instances and improve haddock
Browse files Browse the repository at this point in the history
  • Loading branch information
lehins committed Feb 3, 2024
1 parent 5fb946b commit c6e68a6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 33 deletions.
67 changes: 45 additions & 22 deletions src/System/Random/Seed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ module System.Random.Seed
, unSeed
, mkSeedFromByteString
, unSeedToByteString
, withSeed
, withSeedM
, withSeedFile
, nonEmptyToSeed
, nonEmptyFromSeed
Expand All @@ -42,6 +44,7 @@ import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short.Internal as SBS (fromShort, toShort)
import Data.Coerce
import Data.Functor.Identity (runIdentity)
import Data.List.NonEmpty as NE (NonEmpty(..), nonEmpty, toList)
import Data.Typeable
import Data.Word
Expand All @@ -50,14 +53,6 @@ import System.Random.Internal
import qualified System.Random.SplitMix as SM
import qualified System.Random.SplitMix32 as SM32

data FiveByteGen = FiveByteGen Word8 Word32 deriving Show

instance SeedGen FiveByteGen where
type SeedSize FiveByteGen = 5
seedGen64 (w64 :| _) = FiveByteGen (fromIntegral (w64 `shiftR` 32)) (fromIntegral w64)
unseedGen64 (FiveByteGen x1 x4) =
let w64 = (fromIntegral x1 `shiftL` 32) .|. fromIntegral x4
in (w64 :| [])

-- | Interface for converting a pure pseudo-random number generator to and from non-empty
-- sequence of bytes. Seeds are stored in Little-Endian order regardless of the platform
Expand Down Expand Up @@ -104,15 +99,12 @@ instance SeedGen FiveByteGen where
-- @since 1.3.0
class (KnownNat (SeedSize g), 1 <= SeedSize g, Typeable g) => SeedGen g where
-- | Number of bytes that is required for storing the full state of a pseudo-random
-- number generator. It should be big enough to satisfy the roundtrip properies:
-- number generator. It should be big enough to satisfy the roundtrip property:
--
-- @
-- > seedGen (unseedGen gen) == gen
-- @
--
-- @
-- > unseedGen (seedGen seed) == seed
-- @
type SeedSize g :: Nat
{-# MINIMAL (seedGen, unseedGen)|(seedGen64, unseedGen64) #-}

Expand Down Expand Up @@ -150,13 +142,13 @@ class (KnownNat (SeedSize g), 1 <= SeedSize g, Typeable g) => SeedGen g where

instance SeedGen StdGen where
type SeedSize StdGen = SeedSize SM.SMGen
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed SM.SMGen -> SM.SMGen)
unseedGen = coerce (unseedGen :: SM.SMGen -> Seed SM.SMGen)

instance SeedGen g => SeedGen (StateGen g) where
type SeedSize (StateGen g) = SeedSize g
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed g -> g)
unseedGen = coerce (unseedGen :: g -> Seed g)

instance SeedGen SM.SMGen where
type SeedSize SM.SMGen = 16
Expand Down Expand Up @@ -189,15 +181,15 @@ instance SeedGen SM32.SMGen where
freezeMutableByteArray mba


-- | Get the expected size in bytes of the `Seed`
-- | Get the expected size of the `Seed` in number bytes
--
-- @since 1.3.0
seedSize :: forall g. SeedGen g => Int
seedSize = fromIntegral $ natVal (Proxy :: Proxy (SeedSize g))

-- | Construct a `Seed` from a `ByteArray` of expected length. Whenever `ByteArray` does
-- not match the `SeedSize` specified by the pseudo-random generator, this function will
-- return `Nothing`
-- `F.fail`.
--
-- @since 1.3.0
mkSeed :: forall g m. (SeedGen g, F.MonadFail m) => ByteArray -> m (Seed g)
Expand All @@ -211,9 +203,33 @@ mkSeed ba = do
++ show (genTypeName @g)
pure $ Seed ba

-- | Helper function that allows for operating directly on the `Seed`, while supplying a
-- function that uses the pseudo-random number generator that is constructed from that
-- `Seed`.
--
-- ====__Example__
--
-- >>> :set -XTypeApplications
-- >>> withSeed (nonEmptyToSeed (pure 2024) :: Seed StdGen) (random @Int)
-- (1039666877624726199,Seed [0xe9, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
--
-- @since 1.3.0
withSeed :: SeedGen g => Seed g -> (g -> (a, g)) -> (a, Seed g)
withSeed seed f = runIdentity (withSeedM seed (pure . f))

-- | Same as `withSeed`, except it is useful with monadic computation and frozen generators.
--
-- See `System.Random.Stateful.withMutableSeedGen` for a helper that also handles seeds
-- for mutable pseduo-random number generators.
--
-- @since 1.3.0
withSeedM :: (SeedGen g, Functor f) => Seed g -> (g -> f (a, g)) -> f (a, Seed g)
withSeedM seed f = fmap unseedGen <$> f (seedGen seed)

-- | This is a function that shows the name of the generator type, which is useful for
-- error reporting.
--
-- @since 1.3.0
genTypeName :: forall g. SeedGen g => String
genTypeName = show (typeOf (Proxy @g))

Expand Down Expand Up @@ -246,11 +262,13 @@ withSeedFile :: (SeedGen g, MonadIO m) => FilePath -> (g -> m (a, g)) -> m a
withSeedFile fileName f = do
bs <- liftIO $ BS.readFile fileName
seed <- liftIO $ mkSeedFromByteString bs
(res, gen) <- f $ seedGen seed
liftIO $ BS.writeFile fileName $ unSeedToByteString $ unseedGen gen
(res, seed') <- withSeedM seed f
liftIO $ BS.writeFile fileName $ unSeedToByteString seed'
pure res


-- | Construct a seed from a list of 64-bit words. At most `SeedSize` many bytes will be used.
--
-- @since 1.3.0
nonEmptyToSeed :: forall g. SeedGen g => NonEmpty Word64 -> Seed g
nonEmptyToSeed xs = Seed $ runST $ do
let n = seedSize @g
Expand All @@ -262,12 +280,17 @@ nonEmptyToSeed xs = Seed $ runST $ do
w:ws -> w <$ put ws
freezeMutableByteArray mba

-- | Convert a `Seed` to a list of 64bit words.
--
-- @since 1.3.0
nonEmptyFromSeed :: forall g. SeedGen g => Seed g -> NonEmpty Word64
nonEmptyFromSeed (Seed ba) =
case nonEmpty $ reverse $ goWord64 0 [] of
Just ne -> ne
Nothing -> -- Seed is at least 1 byte in size, so it can't be empty
error $ "Impossible: Seed must be at least: "
error $ "Impossible: Seed for "
++ genTypeName @g
++ " must be at least: "
++ show (seedSize @g)
++ " bytes, but got "
++ show n
Expand Down
20 changes: 9 additions & 11 deletions src/System/Random/Stateful.hs
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,7 @@ withMutableGen_ fg action = thawGen fg >>= action
--
-- @since 1.3.0
withMutableSeedGen :: (SeedGen g, ThawedGen g m) => Seed g -> (MutableGen g m -> m a) -> m (a, Seed g)
withMutableSeedGen seed f = do
(res, frozenGen) <- withMutableGen (seedGen seed) f
pure (res, unseedGen frozenGen)
withMutableSeedGen seed f = withSeedM seed (`withMutableGen` f)

-- | Just like `withMutableSeedGen`, except it doesn't return the final generator, only
-- the resulting value. This is slightly more efficient, since it doesn't incur overhead
Expand Down Expand Up @@ -375,8 +373,8 @@ newtype AtomicGen g = AtomicGen { unAtomicGen :: g}
-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families
instance SeedGen g => SeedGen (AtomicGen g) where
type SeedSize (AtomicGen g) = SeedSize g
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed g -> g)
unseedGen = coerce (unseedGen :: g -> Seed g)

-- | Creates a new 'AtomicGenM'.
--
Expand Down Expand Up @@ -472,8 +470,8 @@ newtype IOGen g = IOGen { unIOGen :: g }
-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families
instance SeedGen g => SeedGen (IOGen g) where
type SeedSize (IOGen g) = SeedSize g
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed g -> g)
unseedGen = coerce (unseedGen :: g -> Seed g)

-- | Creates a new 'IOGenM'.
--
Expand Down Expand Up @@ -548,8 +546,8 @@ newtype STGen g = STGen { unSTGen :: g }
-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families
instance SeedGen g => SeedGen (STGen g) where
type SeedSize (STGen g) = SeedSize g
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed g -> g)
unseedGen = coerce (unseedGen :: g -> Seed g)

-- | Creates a new 'STGenM'.
--
Expand Down Expand Up @@ -649,8 +647,8 @@ newtype TGen g = TGen { unTGen :: g }
-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families
instance SeedGen g => SeedGen (TGen g) where
type SeedSize (TGen g) = SeedSize g
seedGen = seedGen . coerce
unseedGen = coerce . unseedGen
seedGen = coerce (seedGen :: Seed g -> g)
unseedGen = coerce (unseedGen :: g -> Seed g)

-- | Creates a new 'TGenM' in `STM`.
--
Expand Down

0 comments on commit c6e68a6

Please sign in to comment.