diff --git a/random.cabal b/random.cabal index 47a87c0f..ddcef7d5 100644 --- a/random.cabal +++ b/random.cabal @@ -100,6 +100,8 @@ library deepseq >=1.1 && <2, mtl >=2.2 && <2.4, splitmix >=0.1 && <0.2 + if impl(ghc < 9.4) + build-depends: data-array-byte test-suite legacy-test type: exitcode-stdio-1.0 diff --git a/src/System/Random.hs b/src/System/Random.hs index 92f31faa..695cb63f 100644 --- a/src/System/Random.hs +++ b/src/System/Random.hs @@ -28,6 +28,10 @@ module System.Random , Uniform , UniformRange , Finite + -- * Generators for sequences of pseudo-random bytes + , uniformByteArray + , uniformByteString + , uniformFillMutableByteArray -- ** Standard pseudo-random number generator , StdGen diff --git a/src/System/Random/Internal.hs b/src/System/Random/Internal.hs index 6724ad34..1d247258 100644 --- a/src/System/Random/Internal.hs +++ b/src/System/Random/Internal.hs @@ -55,6 +55,7 @@ module System.Random.Internal , uniformViaFiniteM , UniformRange(..) , uniformByteStringM + , uniformShortByteStringM , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M @@ -63,19 +64,24 @@ module System.Random.Internal , uniformEnumRM -- * Generators for sequences of pseudo-random bytes + , uniformByteArray + , uniformFillMutableByteArray + , uniformByteString + , genByteArrayST , genShortByteStringIO , genShortByteStringST + , defaultUnsafeUniformFillMutableByteArray ) where import Control.Arrow import Control.DeepSeq (NFData) import Control.Monad (when, (>=>)) import Control.Monad.Cont (ContT, runContT) -import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Identity (runIdentityT) import Control.Monad.ST -import Control.Monad.ST.Unsafe -import Control.Monad.State.Strict (MonadState(..), State, StateT(..), runState) -import Control.Monad.Trans (lift) +import Control.Monad.State.Strict (MonadState(..), State, StateT(..), execStateT, runState) +import Control.Monad.Trans (lift, MonadTrans) +import Data.Array.Byte (ByteArray(..), MutableByteArray(..)) import Data.Bits import Data.ByteString.Short.Internal (ShortByteString(SBS), fromShort) import Data.IORef (IORef, newIORef) @@ -86,6 +92,7 @@ import Foreign.Storable (Storable) import GHC.Exts import GHC.Generics import GHC.IO (IO(..)) +import GHC.ST (ST(..)) import GHC.Word import Numeric.Natural (Natural) import System.IO.Unsafe (unsafePerformIO) @@ -178,15 +185,31 @@ class RandomGen g where genWord64R m g = runStateGen g (unsignedBitmaskWithRejectionM uniformWord64 m) {-# INLINE genWord64R #-} - -- | @genShortByteString n g@ returns a 'ShortByteString' of length @n@ - -- filled with pseudo-random bytes. + -- | Same as `uniformByteArray`, but for `ShortByteString`. @genShortByteString n g@ returns + -- a 'ShortByteString' of length @n@ filled with pseudo-random bytes. -- -- @since 1.2.0 genShortByteString :: Int -> g -> (ShortByteString, g) genShortByteString n g = - unsafePerformIO $ runStateGenT g (genShortByteStringIO n . uniformWord64) + case uniformByteArray False n g of + (ByteArray ba#, g') -> (SBS ba#, g') {-# INLINE genShortByteString #-} + unsafeUniformFillMutableByteArray :: + MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Offset must + -- be non-negative, but this will not be checked + -> Int + -- ^ Number of randomly generated bytes to write into the array. Number of bytes + -- must be non-negative and less then the total size of the array, minus the + -- offset. This also will be checked. + -> g + -> ST s g + unsafeUniformFillMutableByteArray = defaultUnsafeUniformFillMutableByteArray + {-# INLINE unsafeUniformFillMutableByteArray #-} + -- | Yields the range of values returned by 'next'. -- -- It is required that: @@ -277,14 +300,28 @@ class Monad m => StatefulGen g m where pure (shiftL (fromIntegral h32) 32 .|. fromIntegral l32) {-# INLINE uniformWord64 #-} + -- | @uniformByteArrayM n g@ generates a 'ByteArray' of length @n@ + -- filled with pseudo-random bytes. + -- + -- @since 1.3.0 + uniformByteArrayM :: + Bool -- ^ Should `ByteArray` be allocated as pinned memory or not + -> Int -- ^ Size of the newly created `ByteArray` in number of bytes. + -> g -- ^ Generator to use for filling in the newly created `ByteArray` + -> m ByteArray + default uniformByteArrayM :: + (RandomGen f, FrozenGen f m, g ~ MutableGen f m) => Bool -> Int -> g -> m ByteArray + uniformByteArrayM isPinned n g = modifyGen g (uniformByteArray isPinned n) + {-# INLINE uniformByteArrayM #-} + -- | @uniformShortByteString n g@ generates a 'ShortByteString' of length @n@ -- filled with pseudo-random bytes. -- -- @since 1.2.0 uniformShortByteString :: Int -> g -> m ShortByteString - default uniformShortByteString :: MonadIO m => Int -> g -> m ShortByteString - uniformShortByteString n = genShortByteStringIO n . uniformWord64 + uniformShortByteString = uniformShortByteStringM {-# INLINE uniformShortByteString #-} +{-# DEPRECATED uniformShortByteString "In favor of `uniformShortByteStringM`" #-} -- | This class is designed for mutable pseudo-random number generators that have a frozen @@ -382,57 +419,154 @@ splitGen = flip modifyGen split splitMutableGen :: (RandomGen f, ThawedGen f m) => MutableGen f m -> m (MutableGen f m) splitMutableGen = splitGen >=> thawGen - -data MBA = MBA (MutableByteArray# RealWorld) - - -- | Efficiently generates a sequence of pseudo-random bytes in a platform -- independent manner. -- --- @since 1.2.0 -genShortByteStringIO :: - MonadIO m - => Int -- ^ Number of bytes to generate - -> m Word64 -- ^ IO action that can generate 8 random bytes at a time - -> m ShortByteString -genShortByteStringIO n0 gen64 = do - let !n@(I# n#) = max 0 n0 - !n64 = n `quot` 8 +-- @since 1.3.0 +uniformByteArray :: + RandomGen g + => Bool -- ^ Should byte array be allocted in pinned or unpinned memory. + -> Int -- ^ Number of bytes to generate + -> g -- ^ Pure pseudo-random numer generator + -> (ByteArray, g) +uniformByteArray isPinned n0 g = + runST $ do + let !n = max 0 n0 + mba <- + if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + g' <- unsafeUniformFillMutableByteArray mba 0 n g + ba <- freezeMutableByteArray mba + pure (ba, g') +{-# INLINE uniformByteArray #-} + +-- | Using an `ST` action that generates 8 bytes at a type fill in a new `ByteArray` in +-- architecture agnostic manner. +-- +-- @since 1.3.0 +genByteArrayST :: Bool -> Int -> ST s Word64 -> ST s ByteArray +genByteArrayST isPinned n0 action = do + let !n = max 0 n0 + mba <- if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + runIdentityT $ defaultUnsafeUniformFillMutableByteArrayT mba 0 n (lift action) + freezeMutableByteArray mba +{-# INLINE genByteArrayST #-} + +-- | Fill in a slice of a mutable byte array with randomly generated bytes. This +-- function does not fail. +-- +-- @since 1.3.0 +uniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Whenever + -- negative offset is supplied it will be treaded as 0 + -> Int + -- ^ Number of randomly generated bytes to write into the array. Supplied number of + -- bytes will be clamped between 0 and the total size of the array, minus the + -- offset. + -> g + -> ST s g +uniformFillMutableByteArray mba i0 n g = do + sz <- getSizeOfMutableByteArray mba + let !offset = max (min 0 i0) sz + !numBytes = max (min 0 n) (sz - offset) + unsafeUniformFillMutableByteArray mba offset numBytes g +{-# INLINE uniformFillMutableByteArray #-} + +defaultUnsafeUniformFillMutableByteArrayT :: + (Monad (t (ST s)), MonadTrans t) + => MutableByteArray s + -> Int + -> Int + -> t (ST s) Word64 + -> t (ST s) () +defaultUnsafeUniformFillMutableByteArrayT mba i0 n gen64 = do + let !n64 = n `quot` 8 !nrem = n `rem` 8 - mba@(MBA mba#) <- - liftIO $ IO $ \s# -> - case newByteArray# n# s# of - (# s'#, mba# #) -> (# s'#, MBA mba# #) let go i = when (i < n64) $ do w64 <- gen64 -- Writing 8 bytes at a time in a Little-endian order gives us -- platform portability - liftIO $ writeWord64LE mba i w64 + lift $ writeWord64LE mba i w64 go (i + 1) - go 0 + go i0 when (nrem > 0) $ do w64 <- gen64 -- In order to not mess up the byte order we write 1 byte at a time in -- Little endian order. It is tempting to simply generate as many bytes as we -- still need using smaller generators (eg. uniformWord8), but that would -- result in inconsistent tail when total length is slightly varied. - liftIO $ writeByteSliceWord64LE mba (n - nrem) n w64 - liftIO $ IO $ \s# -> - case unsafeFreezeByteArray# mba# s# of - (# s'#, ba# #) -> (# s'#, SBS ba# #) -{-# INLINE genShortByteStringIO #-} + lift $ writeByteSliceWord64LE mba (n - nrem) n w64 +{-# INLINE defaultUnsafeUniformFillMutableByteArrayT #-} + +-- | Efficiently generates a sequence of pseudo-random bytes in a platform +-- independent manner. +-- +-- @since 1.2.0 +defaultUnsafeUniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -> Int -- ^ Starting offset + -> Int -- ^ Number of random bytes to write into the array + -> g -- ^ ST action that can generate 8 random bytes at a time + -> ST s g +defaultUnsafeUniformFillMutableByteArray mba i0 n g = + flip execStateT g + $ defaultUnsafeUniformFillMutableByteArrayT mba i0 n (state genWord64) +{-# INLINE defaultUnsafeUniformFillMutableByteArray #-} + + +-- | Generates a pseudo-random 'ByteString' of the specified size. +-- +-- @since 1.3.0 +uniformByteString :: RandomGen g => Int -> g -> (ByteString, g) +uniformByteString n g = + case uniformByteArray True n g of + (byteArray, g') -> + (shortByteStringToByteString $ byteArrayToShortByteString byteArray, g') -- Architecture independent helpers: -io_ :: (State# RealWorld -> State# RealWorld) -> IO () -io_ m# = IO $ \s# -> (# m# s#, () #) -{-# INLINE io_ #-} -writeWord8 :: MBA -> Int -> Word8 -> IO () -writeWord8 (MBA mba#) (I# i#) (W8# w#) = io_ (writeWord8Array# mba# i# w#) +st_ :: (State# s -> State# s) -> ST s () +st_ m# = ST $ \s# -> (# m# s#, () #) +{-# INLINE st_ #-} + +ioToST :: IO a -> ST RealWorld a +ioToST (IO m#) = ST m# +{-# INLINE ioToST #-} + +newMutableByteArray :: Int -> ST s (MutableByteArray s) +newMutableByteArray (I# n#) = + ST $ \s# -> + case newByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newMutableByteArray #-} + +newPinnedMutableByteArray :: Int -> ST s (MutableByteArray s) +newPinnedMutableByteArray (I# n#) = + ST $ \s# -> + case newPinnedByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newPinnedMutableByteArray #-} + +freezeMutableByteArray :: MutableByteArray s -> ST s ByteArray +freezeMutableByteArray (MutableByteArray mba#) = + ST $ \s# -> + case unsafeFreezeByteArray# mba# s# of + (# s'#, ba# #) -> (# s'#, ByteArray ba# #) + +writeWord8 :: MutableByteArray s -> Int -> Word8 -> ST s () +writeWord8 (MutableByteArray mba#) (I# i#) (W8# w#) = st_ (writeWord8Array# mba# i# w#) {-# INLINE writeWord8 #-} -writeByteSliceWord64LE :: MBA -> Int -> Int -> Word64 -> IO () +writeByteSliceWord64LE :: MutableByteArray s -> Int -> Int -> Word64 -> ST s () writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx where go !i !z = @@ -441,47 +575,48 @@ writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx go (i + 1) (z `shiftR` 8) {-# INLINE writeByteSliceWord64LE #-} -writeWord64LE :: MBA -> Int -> Word64 -> IO () +writeWord64LE :: MutableByteArray s -> Int -> Word64 -> ST s () #ifdef WORDS_BIGENDIAN writeWord64LE mba i w64 = do let !i8 = i * 8 writeByteSliceWord64LE mba i8 (i8 + 8) w64 #else -writeWord64LE (MBA mba#) (I# i#) w64@(W64# w64#) - | wordSizeInBits == 64 = io_ (writeWord64Array# mba# i# w64#) +writeWord64LE (MutableByteArray mba#) (I# i#) w64@(W64# w64#) + | wordSizeInBits == 64 = st_ (writeWord64Array# mba# i# w64#) | otherwise = do let !i32# = i# *# 2# !(W32# w32l#) = fromIntegral w64 !(W32# w32u#) = fromIntegral (w64 `shiftR` 32) - io_ (writeWord32Array# mba# i32# w32l#) - io_ (writeWord32Array# mba# (i32# +# 1#) w32u#) + st_ (writeWord32Array# mba# i32# w32l#) + st_ (writeWord32Array# mba# (i32# +# 1#) w32u#) #endif {-# INLINE writeWord64LE #-} +getSizeOfMutableByteArray :: MutableByteArray s -> ST s Int +getSizeOfMutableByteArray (MutableByteArray mba#) = +#if __GLASGOW_HASKELL__ >=802 + ST $ \s -> + case getSizeofMutableByteArray# mba# s of + (# s', n# #) -> (# s', I# n# #) +#else + pure $! (I# sizeofMutableByteArray# mba#) +#endif +{-# INLINE getSizeOfMutableByteArray #-} --- | Same as 'genShortByteStringIO', but runs in 'ST'. --- --- @since 1.2.0 -genShortByteStringST :: Int -> ST s Word64 -> ST s ShortByteString -genShortByteStringST n action = - unsafeIOToST (genShortByteStringIO n (unsafeSTToIO action)) -{-# INLINE genShortByteStringST #-} - +byteArrayToShortByteString :: ByteArray -> ShortByteString +byteArrayToShortByteString (ByteArray ba#) = SBS ba# --- | Generates a pseudo-random 'ByteString' of the specified size. --- --- @since 1.2.0 -uniformByteStringM :: StatefulGen g m => Int -> g -> m ByteString -uniformByteStringM n g = do - ba <- uniformShortByteString n g - pure $ +-- | Convert a ShortByteString to ByteString by casting, whenever memory is pinned, +-- otherwise make a copy into a new pinned ByteString +shortByteStringToByteString :: ShortByteString -> ByteString +shortByteStringToByteString ba = #if __GLASGOW_HASKELL__ < 802 - fromShort ba + fromShort ba #else - let !(SBS ba#) = ba in - if isTrue# (isByteArrayPinned# ba#) - then pinnedByteArrayToByteString ba# - else fromShort ba + let !(SBS ba#) = ba in + if isTrue# (isByteArrayPinned# ba#) + then pinnedByteArrayToByteString ba# + else fromShort ba {-# INLINE uniformByteStringM #-} pinnedByteArrayToByteString :: ByteArray# -> ByteString @@ -495,6 +630,39 @@ pinnedByteArrayToForeignPtr ba# = {-# INLINE pinnedByteArrayToForeignPtr #-} #endif +-- | Same as 'genShortByteStringIO', but runs in 'ST'. +-- +-- @since 1.2.0 +genShortByteStringST :: Int -> ST s Word64 -> ST s ShortByteString +genShortByteStringST n0 action = byteArrayToShortByteString <$> genByteArrayST False n0 action +{-# INLINE genShortByteStringST #-} + +-- | Efficiently fills in a new `ShortByteString` in a platform independent manner. +-- +-- @since 1.2.0 +genShortByteStringIO :: + Int -- ^ Number of bytes to generate + -> IO Word64 -- ^ IO action that can generate 8 random bytes at a time + -> IO ShortByteString +genShortByteStringIO n ioAction = stToIO $ genShortByteStringST n (ioToST ioAction) +{-# INLINE genShortByteStringIO #-} + +-- | @uniformShortByteString n g@ generates a 'ShortByteString' of length @n@ +-- filled with pseudo-random bytes. +-- +-- @since 1.3.0 +uniformShortByteStringM :: StatefulGen g m => Int -> g -> m ShortByteString +uniformShortByteStringM n g = byteArrayToShortByteString <$> uniformByteArrayM False n g +{-# INLINE uniformShortByteStringM #-} + +-- | Generates a pseudo-random 'ByteString' of the specified size. +-- +-- @since 1.2.0 +uniformByteStringM :: StatefulGen g m => Int -> g -> m ByteString +uniformByteStringM n g = + shortByteStringToByteString . byteArrayToShortByteString + <$> uniformByteArrayM True n g + -- | Opaque data type that carries the type of a pure pseudo-random number -- generator. diff --git a/src/System/Random/Stateful.hs b/src/System/Random/Stateful.hs index 0e42cc3c..2e1e4845 100644 --- a/src/System/Random/Stateful.hs +++ b/src/System/Random/Stateful.hs @@ -28,7 +28,15 @@ module System.Random.Stateful -- * Mutable pseudo-random number generator interfaces -- $interfaces - , StatefulGen(..) + , StatefulGen + ( uniformWord32R + , uniformWord64R + , uniformWord8 + , uniformWord16 + , uniformWord32 + , uniformWord64 + , uniformShortByteString + ) , FrozenGen(..) , ThawedGen(..) , withMutableGen @@ -45,7 +53,7 @@ module System.Random.Stateful -- * Monadic adapters for pure pseudo-random number generators #monadicadapters# -- $monadicadapters - -- ** Pure adapter + -- ** Pure adapter in 'StateT' , StateGen(..) , StateGenM(..) , runStateGen @@ -86,14 +94,23 @@ module System.Random.Stateful , uniformViaFiniteM , UniformRange(..) - -- * Generators for sequences of pseudo-random bytes + -- ** Generators for sequences of pseudo-random bytes + , uniformByteArrayM + , uniformByteStringM + , uniformShortByteStringM + + -- * Helper functions for createing instances + -- ** Sequences of bytes + , genByteArrayST , genShortByteStringIO , genShortByteStringST - , uniformByteStringM + , defaultUnsafeUniformFillMutableByteArray + -- ** Floating point numbers , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M , uniformFloatPositive01M + -- ** Enum types , uniformEnumM , uniformEnumRM diff --git a/stack-oldish.yaml b/stack-oldish.yaml new file mode 100644 index 00000000..06ec8886 --- /dev/null +++ b/stack-oldish.yaml @@ -0,0 +1,5 @@ +resolver: lts-18.28 +packages: +- . +extra-deps: +- data-array-byte-0.1.0.1@sha256:2ef1bd3511e82ba56f7f23cd793dd2da84338a1e7c2cbea5b151417afe3baada,1989 diff --git a/stack.yaml b/stack.yaml index d15416fa..7f9dc000 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-18.28 +resolver: lts-21.17 packages: - . extra-deps: [] diff --git a/test/Spec.hs b/test/Spec.hs index 078d4d0a..e9ac6874 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -125,7 +125,7 @@ byteStringSpec = let g = mkStdGen 2021 bs = [78,232,117,189,13,237,63,84,228,82,19,36,191,5,128,192] :: [Word8] forM_ [0 .. length bs - 1] $ \ n -> do - xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteString n) + xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteStringM n) xs @?= take n bs ys <- BS.unpack <$> runStateGenT_ g (uniformByteStringM n) ys @?= xs diff --git a/test/Spec/Stateful.hs b/test/Spec/Stateful.hs index dbed18c4..e6b48197 100644 --- a/test/Spec/Stateful.hs +++ b/test/Spec/Stateful.hs @@ -155,11 +155,11 @@ frozenGenSpecFor fromStdGen toStdGen runStatefulGen = , testProperty "uniformWord64R/genWord64R" $ forAll $ \w64 -> matchRandomGenSpec (uniformWord64R w64) (genWord64R w64) fromStdGen toStdGen runStatefulGen - , testProperty "uniformShortByteString/genShortByteString" $ + , testProperty "uniformShortByteStringM/genShortByteString" $ forAll $ \(NonNegative n') -> let n = n' `mod` 100000 -- Ensure it is not too big in matchRandomGenSpec - (uniformShortByteString n) + (uniformShortByteStringM n) (genShortByteString n) fromStdGen toStdGen