diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d965adb4..7286ae75 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -108,7 +108,7 @@ jobs: stack-yaml: stack-old.yaml - resolver: lts-18 ghc: '8.10.7' - stack-yaml: stack.yaml + stack-yaml: stack.lts-18.yaml - resolver: lts-19 ghc: '9.0.2' stack-yaml: stack-coveralls.yaml @@ -271,8 +271,10 @@ jobs: githubToken: ${{ github.token }} install: | apt-get update -y - apt-get install -y ghc libghc-tasty-smallcheck-dev libghc-tasty-hunit-dev libghc-splitmix-dev curl + apt-get install -y git ghc libghc-tasty-smallcheck-dev libghc-tasty-hunit-dev libghc-splitmix-dev curl run: | + git clone https://github.com/Bodigrim/data-array-byte + cp -r data-array-byte/Data . ghc --version ghc --make -isrc:test-legacy -o legacy test-legacy/Legacy.hs ./legacy diff --git a/CHANGELOG.md b/CHANGELOG.md index bc2cda35..412a7b5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,27 @@ # 1.3.0 -* Move `thawGen` from `FreezeGen` into the new `ThawGen` type class. Fixes an issue with - an unlawful instance of `StateGen` for `FreezeGen`. -* Add `modifyGen` and `overwriteGen` to the `FrozenGen` type class -* Add `splitGen` and `splitMutableGen` -* Switch `randomM` and `randomRM` to use `FrozenGen` instead of `RandomGenM` -* Deprecate `RandomGenM` in favor of a more powerful `FrozenGen` +* Add compatibility with recently added `ByteArray` to `base`: + [#153](https://github.com/haskell/random/pull/153) + * Switch to using `ByteArray` for type class implementation instead of + `ShortByteString` + * Add `unsafeUniformFillMutableByteArray` to `RandomGen` and a helper function + `defaultUnsafeUniformFillMutableByteArray` that makes implementation + for most instances easier. 
+ * Add `uniformByteArray`, `uniformByteString` and `uniformFillMutableByteArray` + * Add `uniformByteArrayM` to `StatefulGen` + * Add `uniformByteStringM` and `uniformShortByteStringM` + * Deprecate `uniformShortByteString` in favor of `uniformShortByteStringM` for + consistent naming and a future plan of removing it from `StatefulGen` + type class + * Expose a helper function `genByteArrayST`, that can be used for + defining implementation for `uniformByteArrayM` +* Improve `FrozenGen` interface: [#149](https://github.com/haskell/random/pull/149) + * Move `thawGen` from `FreezeGen` into the new `ThawGen` type class. Fixes an issue with + an unlawful instance of `StateGen` for `FreezeGen`. + * Add `modifyGen` and `overwriteGen` to the `FrozenGen` type class + * Add `splitGen` and `splitMutableGen` + * Switch `randomM` and `randomRM` to use `FrozenGen` instead of `RandomGenM` + * Deprecate `RandomGenM` in favor of a more powerful `FrozenGen` * Add `isInRange` to `UniformRange`: [#78](https://github.com/haskell/random/pull/78) * Add default implementation for `uniformRM` using `Generics`: [#92](https://github.com/haskell/random/pull/92) diff --git a/bench/Main.hs b/bench/Main.hs index 42fb04f7..cb0c5ba0 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -25,6 +25,8 @@ seed = 1337 main :: IO () main = do let !sz = 100000 + !sz100MiB = 100 * 1024 * 1024 + genLengths :: ([Int], StdGen) genLengths = -- create 5000 small lengths that are needed for ShortByteString generation runStateGen (mkStdGen 2020) $ \g -> replicateM 5000 (uniformRM (16 + 1, 16 + 7) g) @@ -243,16 +245,18 @@ main = do sz ] ] - , bgroup "ShortByteString" - [ env (pure genLengths) $ \ ~(ns, gen) -> - bench "genShortByteString" $ - nfIO $ runStateGenT gen $ \g -> mapM (`uniformShortByteString` g) ns - ] - , bgroup "ByteString" - [ env getStdGen $ \gen -> - bench "genByteString 100MB" $ - nfIO $ runStateGenT gen $ uniformByteStringM 100000000 - ] + ] + , bgroup "Bytes" + [ env (pure genLengths) $ \ ~(ns, 
gen) -> + bench "uniformShortByteStringM" $ + nfIO $ runStateGenT gen $ \g -> mapM (`uniformShortByteStringM` g) ns + , env getStdGen $ \gen -> + bench "uniformByteStringM 100MB" $ + nf (runStateGen gen . uniformByteStringM) sz100MiB + , env getStdGen $ \gen -> + bench "uniformByteArray 100MB" $ nf (\n -> uniformByteArray False n gen) sz100MiB + , env getStdGen $ \gen -> + bench "genByteString 100MB" $ nf (\k -> genByteString k gen) sz100MiB ] ] ] diff --git a/random.cabal b/random.cabal index 47a87c0f..ddcef7d5 100644 --- a/random.cabal +++ b/random.cabal @@ -100,6 +100,8 @@ library deepseq >=1.1 && <2, mtl >=2.2 && <2.4, splitmix >=0.1 && <0.2 + if impl(ghc < 9.4) + build-depends: data-array-byte test-suite legacy-test type: exitcode-stdio-1.0 diff --git a/src/System/Random.hs b/src/System/Random.hs index 92f31faa..9290d242 100644 --- a/src/System/Random.hs +++ b/src/System/Random.hs @@ -20,14 +20,28 @@ module System.Random -- * Pure number generator interface -- $interfaces - RandomGen(..) + RandomGen + ( split + , genWord8 + , genWord16 + , genWord32 + , genWord64 + , genWord32R + , genWord64R + , unsafeUniformFillMutableByteArray + ) , uniform , uniformR - , genByteString , Random(..) , Uniform , UniformRange , Finite + -- * Generators for sequences of pseudo-random bytes + , uniformByteArray + , uniformByteString + , uniformFillMutableByteArray + , genByteString + , genShortByteString -- ** Standard pseudo-random number generator , StdGen @@ -45,6 +59,8 @@ module System.Random -- * Compatibility and reproducibility -- ** Backwards compatibility and deprecations + , genRange + , next -- $deprecations -- ** Reproducibility @@ -199,6 +215,9 @@ uniformR r g = runStateGen g (uniformRM r) -- >>> unpack . fst . genByteString 10 $ pureGen -- [51,123,251,37,49,167,90,109,1,4] -- +-- /Note/ - This function is equivalent to `uniformByteString` and will be deprecated in +-- the next major release. 
+-- -- @since 1.2.0 genByteString :: RandomGen g => Int -> g -> (ByteString, g) genByteString n g = runStateGenST g (uniformByteStringM n) diff --git a/src/System/Random/Internal.hs b/src/System/Random/Internal.hs index 6724ad34..4c87f18b 100644 --- a/src/System/Random/Internal.hs +++ b/src/System/Random/Internal.hs @@ -54,7 +54,6 @@ module System.Random.Internal , Uniform(..) , uniformViaFiniteM , UniformRange(..) - , uniformByteStringM , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M @@ -63,19 +62,31 @@ module System.Random.Internal , uniformEnumRM -- * Generators for sequences of pseudo-random bytes + , uniformByteStringM + , uniformShortByteStringM + , uniformByteArray + , uniformFillMutableByteArray + , uniformByteString + , genByteArrayST , genShortByteStringIO , genShortByteStringST + , defaultUnsafeUniformFillMutableByteArray + -- ** Helpers for dealing with MutableByteArray + , newMutableByteArray + , newPinnedMutableByteArray + , freezeMutableByteArray + , writeWord8 ) where import Control.Arrow import Control.DeepSeq (NFData) import Control.Monad (when, (>=>)) import Control.Monad.Cont (ContT, runContT) -import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Identity (IdentityT (runIdentityT)) import Control.Monad.ST -import Control.Monad.ST.Unsafe -import Control.Monad.State.Strict (MonadState(..), State, StateT(..), runState) -import Control.Monad.Trans (lift) +import Control.Monad.State.Strict (MonadState(..), State, StateT(..), execStateT, runState) +import Control.Monad.Trans (lift, MonadTrans) +import Data.Array.Byte (ByteArray(..), MutableByteArray(..)) import Data.Bits import Data.ByteString.Short.Internal (ShortByteString(SBS), fromShort) import Data.IORef (IORef, newIORef) @@ -86,6 +97,7 @@ import Foreign.Storable (Storable) import GHC.Exts import GHC.Generics import GHC.IO (IO(..)) +import GHC.ST (ST(..)) import GHC.Word import Numeric.Natural (Natural) import System.IO.Unsafe (unsafePerformIO) @@ -178,15 +190,37 @@ 
class RandomGen g where genWord64R m g = runStateGen g (unsignedBitmaskWithRejectionM uniformWord64 m) {-# INLINE genWord64R #-} - -- | @genShortByteString n g@ returns a 'ShortByteString' of length @n@ - -- filled with pseudo-random bytes. + -- | Same as @`uniformByteArray` `False`@, but for `ShortByteString`. + -- + -- @genShortByteString n g@ returns a 'ShortByteString' of length @n@ filled with + -- pseudo-random bytes. + -- + -- /Note/ - This function will be removed from the type class in the next major release as + -- it is no longer needed because of `unsafeUniformFillMutableByteArray`. + -- -- -- @since 1.2.0 genShortByteString :: Int -> g -> (ShortByteString, g) genShortByteString n g = - unsafePerformIO $ runStateGenT g (genShortByteStringIO n . uniformWord64) + case uniformByteArray False n g of + (ByteArray ba#, g') -> (SBS ba#, g') {-# INLINE genShortByteString #-} + unsafeUniformFillMutableByteArray :: + MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Offset must + -- be non-negative, but this will not be checked + -> Int + -- ^ Number of randomly generated bytes to write into the array. Number of bytes + -- must be non-negative and no greater than the total size of the array minus the + -- offset. This also will not be checked. + -> g + -> ST s g + unsafeUniformFillMutableByteArray = defaultUnsafeUniformFillMutableByteArray + {-# INLINE unsafeUniformFillMutableByteArray #-} + -- | Yields the range of values returned by 'next'. -- -- It is required that: @@ -277,14 +311,28 @@ class Monad m => StatefulGen g m where pure (shiftL (fromIntegral h32) 32 .|. fromIntegral l32) {-# INLINE uniformWord64 #-} + -- | @uniformByteArrayM isPinned n g@ generates a 'ByteArray' of length @n@ + -- filled with pseudo-random bytes. 
+ -- + -- @since 1.3.0 + uniformByteArrayM :: + Bool -- ^ Should `ByteArray` be allocated as pinned memory or not + -> Int -- ^ Size of the newly created `ByteArray` in number of bytes. + -> g -- ^ Generator to use for filling in the newly created `ByteArray` + -> m ByteArray + default uniformByteArrayM :: + (RandomGen f, FrozenGen f m, g ~ MutableGen f m) => Bool -> Int -> g -> m ByteArray + uniformByteArrayM isPinned n g = modifyGen g (uniformByteArray isPinned n) + {-# INLINE uniformByteArrayM #-} + -- | @uniformShortByteString n g@ generates a 'ShortByteString' of length @n@ -- filled with pseudo-random bytes. -- -- @since 1.2.0 uniformShortByteString :: Int -> g -> m ShortByteString - default uniformShortByteString :: MonadIO m => Int -> g -> m ShortByteString - uniformShortByteString n = genShortByteStringIO n . uniformWord64 + uniformShortByteString = uniformShortByteStringM {-# INLINE uniformShortByteString #-} +{-# DEPRECATED uniformShortByteString "In favor of `uniformShortByteStringM`" #-} -- | This class is designed for mutable pseudo-random number generators that have a frozen @@ -382,57 +430,169 @@ splitGen = flip modifyGen split splitMutableGen :: (RandomGen f, ThawedGen f m) => MutableGen f m -> m (MutableGen f m) splitMutableGen = splitGen >=> thawGen - -data MBA = MBA (MutableByteArray# RealWorld) - - -- | Efficiently generates a sequence of pseudo-random bytes in a platform -- independent manner. -- --- @since 1.2.0 -genShortByteStringIO :: - MonadIO m - => Int -- ^ Number of bytes to generate - -> m Word64 -- ^ IO action that can generate 8 random bytes at a time - -> m ShortByteString -genShortByteStringIO n0 gen64 = do - let !n@(I# n#) = max 0 n0 - !n64 = n `quot` 8 +-- @since 1.3.0 +uniformByteArray :: + RandomGen g + => Bool -- ^ Should byte array be allocated in pinned or unpinned memory. 
+ -> Int -- ^ Number of bytes to generate + -> g -- ^ Pure pseudo-random number generator + -> (ByteArray, g) +uniformByteArray isPinned n0 g = + runST $ do + let !n = max 0 n0 + mba <- + if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + g' <- unsafeUniformFillMutableByteArray mba 0 n g + ba <- freezeMutableByteArray mba + pure (ba, g') +{-# INLINE uniformByteArray #-} + +-- | Using an `ST` action that generates 8 bytes at a time, fill in a new `ByteArray` in +-- an architecture agnostic manner. +-- +-- @since 1.3.0 +genByteArrayST :: Bool -> Int -> ST s Word64 -> ST s ByteArray +genByteArrayST isPinned n0 action = do + let !n = max 0 n0 + mba <- if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + runIdentityT $ defaultUnsafeUniformFillMutableByteArrayT mba 0 n (lift action) + freezeMutableByteArray mba +{-# INLINE genByteArrayST #-} + +-- | Fill in a slice of a mutable byte array with randomly generated bytes. This function +-- does not fail, instead it adjusts the offset and number of bytes to generate into a valid +-- range. +-- +-- @since 1.3.0 +uniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Offset will be + -- clamped into the range between 0 and the total size of the mutable array + -> Int + -- ^ Number of randomly generated bytes to write into the array. This number will be + -- clamped between 0 and the total size of the array without the offset. 
+ -> g + -> ST s g +uniformFillMutableByteArray mba i0 n g = do + !sz <- getSizeOfMutableByteArray mba + let !offset = max 0 (min sz i0) + !numBytes = min (sz - offset) (max 0 n) + unsafeUniformFillMutableByteArray mba offset numBytes g +{-# INLINE uniformFillMutableByteArray #-} + +defaultUnsafeUniformFillMutableByteArrayT :: + (Monad (t (ST s)), MonadTrans t) + => MutableByteArray s + -> Int + -> Int + -> t (ST s) Word64 + -> t (ST s) () +defaultUnsafeUniformFillMutableByteArrayT mba offset n gen64 = do + let !n64 = n `quot` 8 + !endIx64 = offset + n64 * 8 !nrem = n `rem` 8 - mba@(MBA mba#) <- - liftIO $ IO $ \s# -> - case newByteArray# n# s# of - (# s'#, mba# #) -> (# s'#, MBA mba# #) - let go i = - when (i < n64) $ do + let go !i = + when (i < endIx64) $ do w64 <- gen64 -- Writing 8 bytes at a time in a Little-endian order gives us -- platform portability - liftIO $ writeWord64LE mba i w64 - go (i + 1) - go 0 + lift $ writeWord64LE mba i w64 + go (i + 8) + go offset when (nrem > 0) $ do + let !endIx = offset + n w64 <- gen64 -- In order to not mess up the byte order we write 1 byte at a time in -- Little endian order. It is tempting to simply generate as many bytes as we -- still need using smaller generators (eg. uniformWord8), but that would -- result in inconsistent tail when total length is slightly varied. 
- liftIO $ writeByteSliceWord64LE mba (n - nrem) n w64 - liftIO $ IO $ \s# -> - case unsafeFreezeByteArray# mba# s# of - (# s'#, ba# #) -> (# s'#, SBS ba# #) -{-# INLINE genShortByteStringIO #-} + lift $ writeByteSliceWord64LE mba (endIx - nrem) endIx w64 +{-# INLINEABLE defaultUnsafeUniformFillMutableByteArrayT #-} +{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT + :: MutableByteArray s + -> Int + -> Int + -> IdentityT (ST s) Word64 + -> IdentityT (ST s) () #-} +{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT + :: MutableByteArray s + -> Int + -> Int + -> StateT g (ST s) Word64 + -> StateT g (ST s) () #-} + +-- | Efficiently fills a slice of a mutable byte array with pseudo-random bytes in a platform +-- independent manner. Offset and number of bytes are not checked. +-- +-- @since 1.3.0 +defaultUnsafeUniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -> Int -- ^ Starting offset + -> Int -- ^ Number of random bytes to write into the array + -> g -- ^ Pure pseudo-random number generator, consumed 8 bytes at a time + -> ST s g +defaultUnsafeUniformFillMutableByteArray mba i0 n g = + flip execStateT g + $ defaultUnsafeUniformFillMutableByteArrayT mba i0 n (state genWord64) +{-# INLINE defaultUnsafeUniformFillMutableByteArray #-} + + +-- | Generates a pseudo-random 'ByteString' of the specified size. 
+-- +-- @since 1.3.0 +uniformByteString :: RandomGen g => Int -> g -> (ByteString, g) +uniformByteString n g = + case uniformByteArray True n g of + (byteArray, g') -> + (shortByteStringToByteString $ byteArrayToShortByteString byteArray, g') +{-# INLINE uniformByteString #-} -- Architecture independent helpers: -io_ :: (State# RealWorld -> State# RealWorld) -> IO () -io_ m# = IO $ \s# -> (# m# s#, () #) -{-# INLINE io_ #-} -writeWord8 :: MBA -> Int -> Word8 -> IO () -writeWord8 (MBA mba#) (I# i#) (W8# w#) = io_ (writeWord8Array# mba# i# w#) +st_ :: (State# s -> State# s) -> ST s () +st_ m# = ST $ \s# -> (# m# s#, () #) +{-# INLINE st_ #-} + +ioToST :: IO a -> ST RealWorld a +ioToST (IO m#) = ST m# +{-# INLINE ioToST #-} + +newMutableByteArray :: Int -> ST s (MutableByteArray s) +newMutableByteArray (I# n#) = + ST $ \s# -> + case newByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newMutableByteArray #-} + +newPinnedMutableByteArray :: Int -> ST s (MutableByteArray s) +newPinnedMutableByteArray (I# n#) = + ST $ \s# -> + case newPinnedByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newPinnedMutableByteArray #-} + +freezeMutableByteArray :: MutableByteArray s -> ST s ByteArray +freezeMutableByteArray (MutableByteArray mba#) = + ST $ \s# -> + case unsafeFreezeByteArray# mba# s# of + (# s'#, ba# #) -> (# s'#, ByteArray ba# #) + +writeWord8 :: MutableByteArray s -> Int -> Word8 -> ST s () +writeWord8 (MutableByteArray mba#) (I# i#) (W8# w#) = st_ (writeWord8Array# mba# i# w#) {-# INLINE writeWord8 #-} -writeByteSliceWord64LE :: MBA -> Int -> Int -> Word64 -> IO () +writeByteSliceWord64LE :: MutableByteArray s -> Int -> Int -> Word64 -> ST s () writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx where go !i !z = @@ -441,48 +601,53 @@ writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx go (i + 1) (z `shiftR` 8) {-# INLINE writeByteSliceWord64LE #-} -writeWord64LE :: MBA -> Int 
-> Word64 -> IO () -#ifdef WORDS_BIGENDIAN -writeWord64LE mba i w64 = do - let !i8 = i * 8 - writeByteSliceWord64LE mba i8 (i8 + 8) w64 +-- On big endian machines we need to write one byte at a time for consistency with little +-- endian machines. Also for GHC versions prior to 8.6 we don't have primops that can +-- write with byte offset, eg. writeWord8ArrayAsWord64# and writeWord8ArrayAsWord32#, so we +-- also must fallback to writing one byte a time. Such fallback results in about 3 times +-- slow down, which is not the end of the world. +writeWord64LE :: MutableByteArray s -> Int -> Word64 -> ST s () +#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806) +writeWord64LE mba i w64 = + writeByteSliceWord64LE mba i (i + 8) w64 #else -writeWord64LE (MBA mba#) (I# i#) w64@(W64# w64#) - | wordSizeInBits == 64 = io_ (writeWord64Array# mba# i# w64#) +writeWord64LE (MutableByteArray mba#) (I# i#) w64@(W64# w64#) + | wordSizeInBits == 64 = st_ (writeWord8ArrayAsWord64# mba# i# w64#) | otherwise = do - let !i32# = i# *# 2# - !(W32# w32l#) = fromIntegral w64 + let !(W32# w32l#) = fromIntegral w64 !(W32# w32u#) = fromIntegral (w64 `shiftR` 32) - io_ (writeWord32Array# mba# i32# w32l#) - io_ (writeWord32Array# mba# (i32# +# 1#) w32u#) + st_ (writeWord8ArrayAsWord32# mba# i# w32l#) + st_ (writeWord8ArrayAsWord32# mba# (i# +# 4#) w32u#) #endif {-# INLINE writeWord64LE #-} +getSizeOfMutableByteArray :: MutableByteArray s -> ST s Int +getSizeOfMutableByteArray (MutableByteArray mba#) = +#if __GLASGOW_HASKELL__ >=802 + ST $ \s -> + case getSizeofMutableByteArray# mba# s of + (# s', n# #) -> (# s', I# n# #) +#else + pure $! I# (sizeofMutableByteArray# mba#) +#endif +{-# INLINE getSizeOfMutableByteArray #-} --- | Same as 'genShortByteStringIO', but runs in 'ST'. 
--- --- @since 1.2.0 -genShortByteStringST :: Int -> ST s Word64 -> ST s ShortByteString -genShortByteStringST n action = - unsafeIOToST (genShortByteStringIO n (unsafeSTToIO action)) -{-# INLINE genShortByteStringST #-} - +byteArrayToShortByteString :: ByteArray -> ShortByteString +byteArrayToShortByteString (ByteArray ba#) = SBS ba# +{-# INLINE byteArrayToShortByteString #-} --- | Generates a pseudo-random 'ByteString' of the specified size. --- --- @since 1.2.0 -uniformByteStringM :: StatefulGen g m => Int -> g -> m ByteString -uniformByteStringM n g = do - ba <- uniformShortByteString n g - pure $ +-- | Convert a ShortByteString to ByteString by casting, whenever memory is pinned, +-- otherwise make a copy into a new pinned ByteString +shortByteStringToByteString :: ShortByteString -> ByteString +shortByteStringToByteString ba = #if __GLASGOW_HASKELL__ < 802 - fromShort ba + fromShort ba #else - let !(SBS ba#) = ba in - if isTrue# (isByteArrayPinned# ba#) - then pinnedByteArrayToByteString ba# - else fromShort ba -{-# INLINE uniformByteStringM #-} + let !(SBS ba#) = ba in + if isTrue# (isByteArrayPinned# ba#) + then pinnedByteArrayToByteString ba# + else fromShort ba +{-# INLINE shortByteStringToByteString #-} pinnedByteArrayToByteString :: ByteArray# -> ByteString pinnedByteArrayToByteString ba# = @@ -495,6 +660,40 @@ pinnedByteArrayToForeignPtr ba# = {-# INLINE pinnedByteArrayToForeignPtr #-} #endif +-- | Same as 'genShortByteStringIO', but runs in 'ST'. +-- +-- @since 1.2.0 +genShortByteStringST :: Int -> ST s Word64 -> ST s ShortByteString +genShortByteStringST n0 action = byteArrayToShortByteString <$> genByteArrayST False n0 action +{-# INLINE genShortByteStringST #-} + +-- | Efficiently fills in a new `ShortByteString` in a platform independent manner. 
+-- +-- @since 1.2.0 +genShortByteStringIO :: + Int -- ^ Number of bytes to generate + -> IO Word64 -- ^ IO action that can generate 8 random bytes at a time + -> IO ShortByteString +genShortByteStringIO n ioAction = stToIO $ genShortByteStringST n (ioToST ioAction) +{-# INLINE genShortByteStringIO #-} + +-- | @uniformShortByteStringM n g@ generates a 'ShortByteString' of length @n@ +-- filled with pseudo-random bytes. +-- +-- @since 1.3.0 +uniformShortByteStringM :: StatefulGen g m => Int -> g -> m ShortByteString +uniformShortByteStringM n g = byteArrayToShortByteString <$> uniformByteArrayM False n g +{-# INLINE uniformShortByteStringM #-} + +-- | Generates a pseudo-random 'ByteString' of the specified size. +-- +-- @since 1.2.0 +uniformByteStringM :: StatefulGen g m => Int -> g -> m ByteString +uniformByteStringM n g = + shortByteStringToByteString . byteArrayToShortByteString + <$> uniformByteArrayM True n g +{-# INLINE uniformByteStringM #-} + -- | Opaque data type that carries the type of a pure pseudo-random number -- generator. 
@@ -522,8 +721,6 @@ instance (RandomGen g, MonadState g m) => StatefulGen (StateGenM g) m where {-# INLINE uniformWord32 #-} uniformWord64 _ = state genWord64 {-# INLINE uniformWord64 #-} - uniformShortByteString n _ = state (genShortByteString n) - {-# INLINE uniformShortByteString #-} instance (RandomGen g, MonadState g m) => FrozenGen (StateGen g) m where type MutableGen (StateGen g) m = StateGenM g @@ -629,6 +826,11 @@ instance RandomGen SM.SMGen where {-# INLINE genWord64 #-} split = SM.splitSMGen {-# INLINE split #-} + -- Despite that this is the same default implementation as in the type class definition, + -- for some mysterious reason without this overwrite, performance of ByteArray generation + -- slows down by a factor of x4: + unsafeUniformFillMutableByteArray = defaultUnsafeUniformFillMutableByteArray + {-# INLINE unsafeUniformFillMutableByteArray #-} instance RandomGen SM32.SMGen where next = SM32.nextInt diff --git a/src/System/Random/Stateful.hs b/src/System/Random/Stateful.hs index 0e42cc3c..20474cc6 100644 --- a/src/System/Random/Stateful.hs +++ b/src/System/Random/Stateful.hs @@ -28,7 +28,15 @@ module System.Random.Stateful -- * Mutable pseudo-random number generator interfaces -- $interfaces - , StatefulGen(..) + , StatefulGen + ( uniformWord32R + , uniformWord64R + , uniformWord8 + , uniformWord16 + , uniformWord32 + , uniformWord64 + , uniformShortByteString + ) , FrozenGen(..) , ThawedGen(..) , withMutableGen @@ -45,7 +53,7 @@ module System.Random.Stateful -- * Monadic adapters for pure pseudo-random number generators #monadicadapters# -- $monadicadapters - -- ** Pure adapter + -- ** Pure adapter in 'MonadState' , StateGen(..) , StateGenM(..) , runStateGen @@ -54,7 +62,7 @@ module System.Random.Stateful , runStateGenT_ , runStateGenST , runStateGenST_ - -- ** Mutable adapter with atomic operations + -- ** Mutable thread-safe adapter in 'IO' , AtomicGen(..) , AtomicGenM(..) 
, newAtomicGenM @@ -72,7 +80,7 @@ module System.Random.Stateful , applySTGen , runSTGen , runSTGen_ - -- ** Mutable adapter in 'STM' + -- ** Mutable thread-safe adapter in 'STM' , TGen(..) , TGenM(..) , newTGenM @@ -86,14 +94,23 @@ module System.Random.Stateful , uniformViaFiniteM , UniformRange(..) - -- * Generators for sequences of pseudo-random bytes + -- ** Generators for sequences of pseudo-random bytes + , uniformByteArrayM + , uniformByteStringM + , uniformShortByteStringM + + -- * Helper functions for creating instances + -- ** Sequences of bytes + , genByteArrayST , genShortByteStringIO , genShortByteStringST - , uniformByteStringM + , defaultUnsafeUniformFillMutableByteArray + -- ** Floating point numbers , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M , uniformFloatPositive01M + -- ** Enum types , uniformEnumM , uniformEnumRM @@ -384,7 +401,6 @@ instance (RandomGen g, MonadIO m) => StatefulGen (AtomicGenM g) m where {-# INLINE uniformWord32 #-} uniformWord64 = applyAtomicGen genWord64 {-# INLINE uniformWord64 #-} - uniformShortByteString n = applyAtomicGen (genShortByteString n) instance (RandomGen g, MonadIO m) => FrozenGen (AtomicGen g) m where @@ -466,7 +482,6 @@ instance (RandomGen g, MonadIO m) => StatefulGen (IOGenM g) m where {-# INLINE uniformWord32 #-} uniformWord64 = applyIOGen genWord64 {-# INLINE uniformWord64 #-} - uniformShortByteString n = applyIOGen (genShortByteString n) instance (RandomGen g, MonadIO m) => FrozenGen (IOGen g) m where @@ -536,7 +551,6 @@ instance RandomGen g => StatefulGen (STGenM g s) (ST s) where {-# INLINE uniformWord32 #-} uniformWord64 = applySTGen genWord64 {-# INLINE uniformWord64 #-} - uniformShortByteString n = applySTGen (genShortByteString n) instance RandomGen g => FrozenGen (STGen g) (ST s) where type MutableGen (STGen g) (ST s) = STGenM g s @@ -640,7 +654,6 @@ instance RandomGen g => StatefulGen (TGenM g) STM where {-# INLINE uniformWord32 #-} uniformWord64 = applyTGen genWord64 {-# INLINE 
uniformWord64 #-} - uniformShortByteString n = applyTGen (genShortByteString n) -- | @since 1.2.1 instance RandomGen g => FrozenGen (TGen g) STM where @@ -795,17 +808,23 @@ applyTGen f (TGenM tvar) = do -- Here is an example instance for the monadic pseudo-random number generator -- from the @mwc-random@ package: -- +-- > import qualified System.Random.MWC as MWC +-- > import qualified Data.Vector.Generic as G +-- -- > instance (s ~ PrimState m, PrimMonad m) => StatefulGen (MWC.Gen s) m where -- > uniformWord8 = MWC.uniform -- > uniformWord16 = MWC.uniform -- > uniformWord32 = MWC.uniform -- > uniformWord64 = MWC.uniform --- > uniformShortByteString n g = stToPrim (genShortByteStringST n (MWC.uniform g)) +-- > uniformByteArrayM isPinned n g = stToPrim (genByteArrayST isPinned n (MWC.uniform g)) -- -- > instance PrimMonad m => FrozenGen MWC.Seed m where -- > type MutableGen MWC.Seed m = MWC.Gen (PrimState m) --- > thawGen = MWC.restore -- > freezeGen = MWC.save +-- > overwriteGen (Gen mv) (Seed v) = G.copy mv v +-- +-- > instance PrimMonad m => ThawedGen MWC.Seed m where +-- > thawGen = MWC.restore -- -- === @FrozenGen@ -- diff --git a/stack-coveralls.yaml b/stack-coveralls.yaml index 6a2143bc..30d15974 100644 --- a/stack-coveralls.yaml +++ b/stack-coveralls.yaml @@ -2,4 +2,5 @@ resolver: lts-19.33 system-ghc: true packages: - . 
-extra-deps: [] +extra-deps: +- data-array-byte-0.1.0.1@sha256:ad89e28b2b046175698fbf542af2ce43e5d2af50aae9f48d12566b1bb3de1d3c,1989 diff --git a/stack-old.yaml b/stack-old.yaml index 025127c5..5c749977 100644 --- a/stack-old.yaml +++ b/stack-old.yaml @@ -16,3 +16,4 @@ extra-deps: - tasty-inspection-testing-0.1@sha256:9c5e76345168fd3a59b43d305eebf8df3c792ce324c66bbdee45b54aa7d2c0ad,1214 - primitive-0.7.4.0@sha256:89b88a3e08493b7727fa4089b0692bfbdf7e1e666ef54635f458644eb8358764,2857 - vector-0.12.3.1@sha256:fffbd00912d69ed7be9bc7eeb09f4f475e0d243ec43f916a9fd5bbd219ce7f3e,8238 +- data-array-byte-0.1.0.1@sha256:ad89e28b2b046175698fbf542af2ce43e5d2af50aae9f48d12566b1bb3de1d3c,1989 diff --git a/stack-oldish.yaml b/stack-oldish.yaml new file mode 100644 index 00000000..06ec8886 --- /dev/null +++ b/stack-oldish.yaml @@ -0,0 +1,5 @@ +resolver: lts-18.28 +packages: +- . +extra-deps: +- data-array-byte-0.1.0.1@sha256:2ef1bd3511e82ba56f7f23cd793dd2da84338a1e7c2cbea5b151417afe3baada,1989 diff --git a/stack.lts-18.yaml b/stack.lts-18.yaml new file mode 100644 index 00000000..320440f4 --- /dev/null +++ b/stack.lts-18.yaml @@ -0,0 +1,5 @@ +resolver: lts-18.28 +packages: +- . +extra-deps: +- data-array-byte-0.1.0.1@sha256:ad89e28b2b046175698fbf542af2ce43e5d2af50aae9f48d12566b1bb3de1d3c,1989 diff --git a/stack.yaml b/stack.yaml index d15416fa..7f9dc000 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-18.28 +resolver: lts-21.17 packages: - . 
extra-deps: [] diff --git a/test/Spec.hs b/test/Spec.hs index 078d4d0a..95f139ad 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -9,6 +9,7 @@ module Main (main) where import Control.Monad (replicateM, forM_) +import Control.Monad.ST (runST) import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import Data.Int @@ -17,12 +18,17 @@ import Data.Void import Data.Word import Foreign.C.Types import GHC.Generics +import GHC.Exts (fromList) import Numeric.Natural (Natural) import System.Random.Stateful +import System.Random.Internal (newMutableByteArray, freezeMutableByteArray, writeWord8) import Test.SmallCheck.Series as SC import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.SmallCheck as SC +#if __GLASGOW_HASKELL__ < 804 +import Data.Monoid ((<>)) +#endif import qualified Spec.Range as Range import qualified Spec.Run as Run @@ -80,6 +86,7 @@ main = , runSpec , floatTests , byteStringSpec + , fillMutableByteArraySpec , SC.testProperty "uniformRangeWithinExcludedF" $ seeded Range.uniformRangeWithinExcludedF , SC.testProperty "uniformRangeWithinExcludedD" $ seeded Range.uniformRangeWithinExcludedD , randomSpec (Proxy :: Proxy (CFloat, CDouble)) @@ -125,12 +132,39 @@ byteStringSpec = let g = mkStdGen 2021 bs = [78,232,117,189,13,237,63,84,228,82,19,36,191,5,128,192] :: [Word8] forM_ [0 .. 
length bs - 1] $ \ n -> do - xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteString n) + xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteStringM n) xs @?= take n bs ys <- BS.unpack <$> runStateGenT_ g (uniformByteStringM n) ys @?= xs ] +fillMutableByteArraySpec :: TestTree +fillMutableByteArraySpec = + testGroup + "MutableByteArray" + [ SC.testProperty "Same as uniformByteArray" $ + forAll $ \isPinned -> seededWithLen $ \n g -> + let baFilled = runST $ do + mba <- newMutableByteArray n + g' <- uniformFillMutableByteArray mba 0 n g + ba <- freezeMutableByteArray mba + pure (ba, g') + in baFilled == uniformByteArray isPinned n g + , SC.testProperty "Safe uniformFillMutableByteArray" $ + forAll $ \isPinned offset count -> seededWithLen $ \sz g -> + let (baFilled, gf) = runST $ do + mba <- newMutableByteArray sz + forM_ [0 .. sz - 1] (\i -> writeWord8 mba i 0) + g' <- uniformFillMutableByteArray mba offset count g + ba <- freezeMutableByteArray mba + pure (ba, g') + (baGen, gu) = uniformByteArray isPinned count' g + offset' = min sz (max 0 offset) + count' = min (sz - offset') (max 0 count) + prefix = replicate offset' 0 + suffix = replicate (sz - (count' + offset')) 0 + in gf == gu && baFilled == fromList prefix <> baGen <> fromList suffix + ] rangeSpec :: forall a. 
diff --git a/test/Spec/Stateful.hs b/test/Spec/Stateful.hs index dbed18c4..167dcc16 100644 --- a/test/Spec/Stateful.hs +++ b/test/Spec/Stateful.hs @@ -155,15 +155,33 @@ frozenGenSpecFor fromStdGen toStdGen runStatefulGen = , testProperty "uniformWord64R/genWord64R" $ forAll $ \w64 -> matchRandomGenSpec (uniformWord64R w64) (genWord64R w64) fromStdGen toStdGen runStatefulGen - , testProperty "uniformShortByteString/genShortByteString" $ + , testProperty "uniformShortByteStringM/genShortByteString" $ forAll $ \(NonNegative n') -> let n = n' `mod` 100000 -- Ensure it is not too big in matchRandomGenSpec - (uniformShortByteString n) + (uniformShortByteStringM n) (genShortByteString n) fromStdGen toStdGen runStatefulGen + , testProperty "uniformByteStringM/uniformByteString" $ + forAll $ \(NonNegative n') -> + let n = n' `mod` 100000 -- Ensure it is not too big + in matchRandomGenSpec + (uniformByteStringM n) + (uniformByteString n) + fromStdGen + toStdGen + runStatefulGen + , testProperty "uniformByteArrayM/genByteArray" $ + forAll $ \(NonNegative n', isPinned1, isPinned2) -> + let n = n' `mod` 100000 -- Ensure it is not too big + in matchRandomGenSpec + (uniformByteArrayM isPinned1 n) + (uniformByteArray isPinned2 n) + fromStdGen + toStdGen + runStatefulGen ] ]