From f23a1732bee893fb12a2a04d75f462802e4cb008 Mon Sep 17 00:00:00 2001 From: Alexey Kuleshevich Date: Sun, 28 Jan 2024 14:34:20 +0100 Subject: [PATCH] Add intial implementation of SeedGen --- CHANGELOG.md | 2 + bench-legacy/SimpleRNGBench.hs | 16 ++- src/System/Random.hs | 5 + src/System/Random/Internal.hs | 219 ++++++++++++++++++++++++++++++--- src/System/Random/Stateful.hs | 10 +- test/Spec.hs | 8 ++ test/Spec/Stateful.hs | 2 +- 7 files changed, 237 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0097b0c7..402ab37a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # 1.3.0 +* Add `Seed`, `SeedGen`, `seedSize`, `mkSeed` and `unSeed`: + [#162](https://github.com/haskell/random/pull/162) * Add `SplitGen` and `splitGen`: [#160](https://github.com/haskell/random/pull/160) * Add `shuffleList` and `shuffleListM`: [#140](https://github.com/haskell/random/pull/140) * Add `mkStdGen64`: [#155](https://github.com/haskell/random/pull/155) diff --git a/bench-legacy/SimpleRNGBench.hs b/bench-legacy/SimpleRNGBench.hs index b941a1b8..dc076c5b 100644 --- a/bench-legacy/SimpleRNGBench.hs +++ b/bench-legacy/SimpleRNGBench.hs @@ -1,8 +1,13 @@ -{-# LANGUAGE BangPatterns, ScopedTypeVariables, ForeignFunctionInterface #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fwarn-unused-imports #-} -- | A simple script to do some very basic timing of the RNGs. - module Main where import System.Exit (exitSuccess, exitFailure) @@ -80,13 +85,18 @@ measureFreq = do -- Test overheads without actually generating any random numbers: data NoopRNG = NoopRNG +instance SeedGen NoopRNG where + type SeedSize NoopRNG = 1 + seedGen = error "NoopRNG" + unseedGen = error "NoopRNG" instance RandomGen NoopRNG where next g = (0, g) genRange _ = (0, 0) split g = (g, g) -- An RNG generating only 0 or 1: -data BinRNG = BinRNG StdGen +newtype BinRNG = BinRNG StdGen + deriving (SeedGen) instance RandomGen BinRNG where next (BinRNG g) = (x `mod` 2, BinRNG g') where diff --git a/src/System/Random.hs b/src/System/Random.hs index 356c69fe..dbb71576 100644 --- a/src/System/Random.hs +++ b/src/System/Random.hs @@ -30,6 +30,11 @@ module System.Random , genWord64R , unsafeUniformFillMutableByteArray ) + , SeedGen (..) + , Seed + , mkSeed + , unSeed + , seedSize , SplitGen (splitGen) , uniform , uniformR diff --git a/src/System/Random/Internal.hs b/src/System/Random/Internal.hs index 9d0c33a5..d921ae3b 100644 --- a/src/System/Random/Internal.hs +++ b/src/System/Random/Internal.hs @@ -1,20 +1,25 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE GHCForeignImportPrim #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE Trustworthy #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnliftedFFITypes #-} -{-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide, not-home #-} -- | @@ -29,6 +34,15 @@ module System.Random.Internal (-- * Pure and monadic pseudo-random number generator interfaces RandomGen(..) , SplitGen(..) + , SeedGen(..) + -- ** Seed + , Seed(..) + , seedSize + , mkSeed + , unSeed + , nonEmptyToSeed + , nonEmptyFromSeed + -- * Stateful , StatefulGen(..) , FrozenGen(..) , ThawedGen(..) @@ -87,7 +101,7 @@ module System.Random.Internal import Control.Arrow import Control.DeepSeq (NFData) -import Control.Monad (replicateM, when, (>=>)) +import Control.Monad (guard, replicateM, when, (>=>)) import Control.Monad.Cont (ContT, runContT) import Control.Monad.Identity (IdentityT (runIdentityT)) import Control.Monad.ST @@ -99,6 +113,8 @@ import Data.ByteString.Short.Internal (ShortByteString(SBS), fromShort) import Data.IORef (IORef, newIORef) import Data.Int import Data.List (sortOn) +import Data.List.NonEmpty (NonEmpty(..), nonEmpty) +import Data.Proxy (Proxy(..)) import Data.Word import Foreign.C.Types import Foreign.Storable (Storable) @@ -106,6 +122,7 @@ import GHC.Exts import GHC.Generics import GHC.IO (IO(..)) import GHC.ST (ST(..)) +import GHC.TypeLits (Nat, KnownNat, natVal, type (<=)) import GHC.Word import Numeric.Natural (Natural) import System.IO.Unsafe (unsafePerformIO) @@ -131,7 +148,7 @@ import Data.ByteString (ByteString) -- @since 1.0.0 {-# DEPRECATED next "No longer used" #-} {-# DEPRECATED genRange "No longer used" #-} -class RandomGen g where +class SeedGen g => RandomGen g where {-# MINIMAL (genWord32|genWord64|(next,genRange)) #-} -- | Returns an 'Int' that is uniformly distributed over the range returned by -- 'genRange' (including both end points), and a new generator. Using 'next' @@ -276,11 +293,105 @@ class RandomGen g => SplitGen g where -- @since 1.3.0 splitGen :: g -> (g, g) +-- | This form of the pseudo-random is designed to be to restore from and easy to serialize. +-- +-- @since 1.3.0 +newtype Seed g = Seed ByteArray + deriving (Eq, Ord) + +-- | Get the expected size in bytes of the `Seed` +-- +-- @since 1.3.0 +seedSize :: forall g. SeedGen g => Int +seedSize = fromIntegral $ natVal (Proxy :: Proxy (SeedSize g)) + +-- | Construct a `Seed` from a `ByteArray` of expected length. Whenever `ByteArray` does +-- not match the `SeedSize` specified by the pseudo-random generator, this function will +-- return `Nothing` +-- +-- @since 1.3.0 +mkSeed :: forall g. SeedGen g => ByteArray -> Maybe (Seed g) +mkSeed ba = do + guard (sizeOfByteArray ba == seedSize @g) + Just $ Seed ba + +-- | Unwrap the `Seed` and get the underlying `ByteArray` +-- +-- @since 1.3.0 +unSeed :: Seed g -> ByteArray +unSeed (Seed ba) = ba + +nonEmptyToSeed :: forall g. SeedGen g => NonEmpty Word64 -> Seed g +nonEmptyToSeed xs = Seed $ runST $ do + let n = seedSize @g + mba <- newMutableByteArray n + _ <- flip runStateT (toList xs) $ do + defaultUnsafeFillMutableByteArrayT mba 0 n $ do + get >>= \case + [] -> pure 0 + w:ws -> w <$ put ws + freezeMutableByteArray mba + +nonEmptyFromSeed :: forall g. SeedGen g => Seed g -> NonEmpty Word64 +nonEmptyFromSeed (Seed ba) = + case nonEmpty $ reverse $ goWord64 0 [] of + Just ne -> ne + Nothing -> -- Seed is at least 1 byte in size, so it can't be empty + error $ "Impossible: Seed must be at least: " + ++ show (seedSize @g) + ++ " bytes, but got " + ++ show n + where + n = sizeOfByteArray ba + n8 = 8 * (n `quot` 8) + goWord64 i !acc + | i < n8 = goWord64 (i + 8) (indexWord64LE ba i : acc) + | i == n = acc + | otherwise = indexByteSliceWord64LE ba i n : acc + +-- | Interface for coverting a pure pseudo-random number generator to and from non-empty +-- sequence of bytes/words. Seeds are stored in Little-Endian order regardless of the platform +-- it is being used on, which provides inter-platform compatibility, while providing +-- optimal performance for most common platforms. +-- +-- Conversion to and from a `Seed` serves as a building block for implementing +-- serialization for any pure or frozen pseudo-random number generator +-- +-- @since 1.3.0 +class (KnownNat (SeedSize g), 1 <= SeedSize g) => SeedGen g where + type SeedSize g :: Nat + {-# MINIMAL (seedGen, unseedGen)|(seedGen64, unseedGen64) #-} + + -- | + -- + -- @since 1.3.0 + seedGen :: Seed g -> g + seedGen = seedGen64 . nonEmptyFromSeed + + -- | + -- + -- @since 1.3.0 + unseedGen :: g -> Seed g + unseedGen = nonEmptyToSeed . unseedGen64 + + -- | + -- + -- @since 1.3.0 + seedGen64 :: NonEmpty Word64 -> g + seedGen64 = seedGen . nonEmptyToSeed + + -- | + -- + -- @since 1.3.0 + unseedGen64 :: g -> NonEmpty Word64 + unseedGen64 = nonEmptyFromSeed . unseedGen + + -- | 'StatefulGen' is an interface to monadic pseudo-random number generators. -- -- @since 1.2.0 class Monad m => StatefulGen g m where - {-# MINIMAL (uniformWord32|uniformWord64) #-} + {-# MINIMAL uniformWord32|uniformWord64 #-} -- | @uniformWord32R upperBound g@ generates a 'Word32' that is uniformly -- distributed over the range @[0, upperBound]@. -- @@ -392,7 +503,7 @@ class Monad m => StatefulGen g m where -- @ -- -- @since 1.2.0 -class StatefulGen (MutableGen f m) m => FrozenGen f m where +class (SeedGen f, StatefulGen (MutableGen f m) m) => FrozenGen f m where {-# MINIMAL (modifyGen|(freezeGen,overwriteGen)) #-} -- | Represents the state of the pseudo-random number generator for use with -- 'thawGen' and 'freezeGen'. @@ -492,7 +603,7 @@ genByteArrayST isPinned n0 action = do mba <- if isPinned then newPinnedMutableByteArray n else newMutableByteArray n - runIdentityT $ defaultUnsafeUniformFillMutableByteArrayT mba 0 n (lift action) + runIdentityT $ defaultUnsafeFillMutableByteArrayT mba 0 n (lift action) freezeMutableByteArray mba {-# INLINE genByteArrayST #-} @@ -520,14 +631,14 @@ uniformFillMutableByteArray mba i0 n g = do unsafeUniformFillMutableByteArray mba offset numBytes g {-# INLINE uniformFillMutableByteArray #-} -defaultUnsafeUniformFillMutableByteArrayT :: +defaultUnsafeFillMutableByteArrayT :: (Monad (t (ST s)), MonadTrans t) => MutableByteArray s -> Int -> Int -> t (ST s) Word64 -> t (ST s) () -defaultUnsafeUniformFillMutableByteArrayT mba offset n gen64 = do +defaultUnsafeFillMutableByteArrayT mba offset n gen64 = do let !n64 = n `quot` 8 !endIx64 = offset + n64 * 8 !nrem = n `rem` 8 @@ -547,14 +658,14 @@ defaultUnsafeUniformFillMutableByteArrayT mba offset n gen64 = do -- still need using smaller generators (eg. uniformWord8), but that would -- result in inconsistent tail when total length is slightly varied. lift $ writeByteSliceWord64LE mba (endIx - nrem) endIx w64 -{-# INLINEABLE defaultUnsafeUniformFillMutableByteArrayT #-} -{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT +{-# INLINEABLE defaultUnsafeFillMutableByteArrayT #-} +{-# SPECIALIZE defaultUnsafeFillMutableByteArrayT :: MutableByteArray s -> Int -> Int -> IdentityT (ST s) Word64 -> IdentityT (ST s) () #-} -{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT +{-# SPECIALIZE defaultUnsafeFillMutableByteArrayT :: MutableByteArray s -> Int -> Int @@ -574,7 +685,7 @@ defaultUnsafeUniformFillMutableByteArray :: -> ST s g defaultUnsafeUniformFillMutableByteArray mba i0 n g = flip execStateT g - $ defaultUnsafeUniformFillMutableByteArrayT mba i0 n (state genWord64) + $ defaultUnsafeFillMutableByteArrayT mba i0 n (state genWord64) {-# INLINE defaultUnsafeUniformFillMutableByteArray #-} @@ -590,6 +701,9 @@ uniformByteString n g = -- Architecture independent helpers: +sizeOfByteArray :: ByteArray -> Int +sizeOfByteArray (ByteArray ba#) = I# (sizeofByteArray# ba#) + st_ :: (State# s -> State# s) -> ST s () st_ m# = ST $ \s# -> (# m# s#, () #) {-# INLINE st_ #-} @@ -631,12 +745,54 @@ writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx go (i + 1) (z `shiftR` 8) {-# INLINE writeByteSliceWord64LE #-} +indexWord8 :: + ByteArray + -> Int -- ^ Offset into immutable byte array in number of bytes + -> Word8 +indexWord8 (ByteArray ba#) (I# i#) = + W8# (indexWord8Array# ba# i#) +{-# INLINE indexWord8 #-} + +indexWord64LE :: + ByteArray + -> Int -- ^ Offset into immutable byte array in number of bytes + -> Word64 +#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806) +indexWord64LE ba i = indexByteSliceWord64LE ba i (i + 8) +#else +indexWord64LE (ByteArray ba#) (I# i#) + | wordSizeInBits == 64 = W64# (indexWord8ArrayAsWord64# ba# i#) + | otherwise = + let !w32l = W32# (indexWord8ArrayAsWord32# ba# i#) + !w32u = W32# (indexWord8ArrayAsWord32# ba# (i# +# 4#)) + in (fromIntegral w32u `shiftL` 32) .|. fromIntegral w32l +#endif +{-# INLINE indexWord64LE #-} + +indexByteSliceWord64LE :: + ByteArray + -> Int -- ^ Starting offset in number of bytes + -> Int -- ^ Ending offset in number of bytes + -> Word64 +indexByteSliceWord64LE ba fromByteIx toByteIx = goWord8 fromByteIx 0 + where + r = (toByteIx - fromByteIx) `rem` 8 + nPadBits = if r == 0 then 0 else 8 * (8 - r) + goWord8 i !w64 + | i < toByteIx = goWord8 (i + 1) (shiftL w64 8 .|. fromIntegral (indexWord8 ba i)) + | otherwise = byteSwap64 (shiftL w64 nPadBits) +{-# INLINE indexByteSliceWord64LE #-} + -- On big endian machines we need to write one byte at a time for consistency with little -- endian machines. Also for GHC versions prior to 8.6 we don't have primops that can -- write with byte offset, eg. writeWord8ArrayAsWord64# and writeWord8ArrayAsWord32#, so we -- also must fallback to writing one byte a time. Such fallback results in about 3 times -- slow down, which is not the end of the world. -writeWord64LE :: MutableByteArray s -> Int -> Word64 -> ST s () +writeWord64LE :: + MutableByteArray s + -> Int -- ^ Offset into mutable byte array in number of bytes + -> Word64 -- ^ 8 bytes that will be written into the supplied array + -> ST s () #if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806) writeWord64LE mba i w64 = writeByteSliceWord64LE mba i (i + 8) w64 @@ -736,7 +892,7 @@ data StateGenM g = StateGenM -- -- @since 1.2.0 newtype StateGen g = StateGen { unStateGen :: g } - deriving (Eq, Ord, Show, RandomGen, Storable, NFData) + deriving (Eq, Ord, Show, SeedGen, RandomGen, Storable, NFData) instance (RandomGen g, MonadState g m) => StatefulGen (StateGenM g) m where uniformWord32R r _ = state (genWord32R r) @@ -892,11 +1048,24 @@ shuffleListM xs gen = do -- | The standard pseudo-random number generator. newtype StdGen = StdGen { unStdGen :: SM.SMGen } - deriving (Show, RandomGen, SplitGen, NFData) + deriving (Show, SeedGen, RandomGen, SplitGen, NFData) instance Eq StdGen where StdGen x1 == StdGen x2 = SM.unseedSMGen x1 == SM.unseedSMGen x2 + +instance SeedGen SM.SMGen where + type SeedSize SM.SMGen = 16 + seedGen (Seed ba) = + SM.seedSMGen (indexWord64LE ba 0) (indexWord64LE ba 8) + unseedGen g = + case SM.unseedSMGen g of + (seed, gamma) -> Seed $ runST $ do + mba <- newMutableByteArray 16 + writeWord64LE mba 0 seed + writeWord64LE mba 8 gamma + freezeMutableByteArray mba + instance RandomGen SM.SMGen where next = SM.nextInt {-# INLINE next #-} @@ -914,6 +1083,24 @@ instance SplitGen SM.SMGen where splitGen = SM.splitSMGen {-# INLINE splitGen #-} +instance SeedGen SM32.SMGen where + type SeedSize SM32.SMGen = 8 + seedGen (Seed ba) = + let x = indexWord64LE ba 0 + seed, gamma :: Word32 + seed = fromIntegral (shiftR x 32) + gamma = fromIntegral x + in SM32.seedSMGen seed gamma + unseedGen g = + let seed, gamma :: Word32 + (seed, gamma) = SM32.unseedSMGen g + in Seed $ runST $ do + mba <- newMutableByteArray 8 + let w64 :: Word64 + w64 = shiftL (fromIntegral seed) 32 .|. fromIntegral gamma + writeWord64LE mba 0 w64 + freezeMutableByteArray mba + instance RandomGen SM32.SMGen where next = SM32.nextInt {-# INLINE next #-} diff --git a/src/System/Random/Stateful.hs b/src/System/Random/Stateful.hs index c89b02b3..ea1ef885 100644 --- a/src/System/Random/Stateful.hs +++ b/src/System/Random/Stateful.hs @@ -7,7 +7,7 @@ {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} - +{-# LANGUAGE UndecidableSuperClasses #-} -- | -- Module : System.Random.Stateful -- Copyright : (c) The University of Glasgow 2001 @@ -351,7 +351,7 @@ newtype AtomicGenM g = AtomicGenM { unAtomicGenM :: IORef g} -- -- @since 1.2.0 newtype AtomicGen g = AtomicGen { unAtomicGen :: g} - deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) + deriving (Eq, Ord, Show, SeedGen, RandomGen, SplitGen, Storable, NFData) -- | Creates a new 'AtomicGenM'. -- @@ -442,7 +442,7 @@ newtype IOGenM g = IOGenM { unIOGenM :: IORef g } -- -- @since 1.2.0 newtype IOGen g = IOGen { unIOGen :: g } - deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) + deriving (Eq, Ord, Show, SeedGen, RandomGen, SplitGen, Storable, NFData) -- | Creates a new 'IOGenM'. @@ -513,7 +513,7 @@ newtype STGenM g s = STGenM { unSTGenM :: STRef s g } -- -- @since 1.2.0 newtype STGen g = STGen { unSTGen :: g } - deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) + deriving (Eq, Ord, Show, SeedGen, RandomGen, SplitGen, Storable, NFData) -- | Creates a new 'STGenM'. -- @@ -608,7 +608,7 @@ newtype TGenM g = TGenM { unTGenM :: TVar g } -- -- @since 1.2.1 newtype TGen g = TGen { unTGen :: g } - deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) + deriving (Eq, Ord, Show, SeedGen, RandomGen, SplitGen, Storable, NFData) -- | Creates a new 'TGenM' in `STM`. -- diff --git a/test/Spec.hs b/test/Spec.hs index c6234193..8c5ca258 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,11 +1,13 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} module Main (main) where import Control.Monad (replicateM, forM_) @@ -14,6 +16,7 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import Data.Int import Data.List (sortOn) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Typeable import Data.Void import Data.Word @@ -297,6 +300,11 @@ instance Monad m => Serial m Foo newtype ConstGen = ConstGen Word64 +instance SeedGen ConstGen where + type SeedSize ConstGen = 8 + seedGen64 (w :| _) = ConstGen w + unseedGen64 (ConstGen w) = pure w + instance RandomGen ConstGen where genWord64 g@(ConstGen c) = (c, g) instance SplitGen ConstGen where diff --git a/test/Spec/Stateful.hs b/test/Spec/Stateful.hs index d0a64e4d..b5e6fb18 100644 --- a/test/Spec/Stateful.hs +++ b/test/Spec/Stateful.hs @@ -174,7 +174,7 @@ frozenGenSpecFor fromStdGen toStdGen runStatefulGen = toStdGen runStatefulGen , testProperty "uniformByteArrayM/genByteArray" $ - forAll $ \(NonNegative n', isPinned1, isPinned2) -> + forAll $ \(NonNegative n', isPinned1 :: Bool, isPinned2 :: Bool) -> let n = n' `mod` 100000 -- Ensure it is not too big in matchRandomGenSpec (uniformByteArrayM isPinned1 n)