From b8c0ae4b8b1e8d4b5f491a9ce182b923a6d63fff Mon Sep 17 00:00:00 2001 From: Alexey Kuleshevich Date: Sun, 29 Oct 2023 01:32:53 +0200 Subject: [PATCH] WIP RandomGen is capable of generating ByteArray progress on bytestring generation Improve haddock Get the StatefulGen generation of ByteArray implemented Get it all building and tests passing Expose some useful helper functions --- random.cabal | 2 + src/System/Random.hs | 4 + src/System/Random/Internal.hs | 284 +++++++++++++++++++++++++++------- src/System/Random/Stateful.hs | 25 ++- stack-oldish.yaml | 5 + stack.yaml | 2 +- test/Spec.hs | 2 +- test/Spec/Stateful.hs | 2 +- 8 files changed, 261 insertions(+), 65 deletions(-) create mode 100644 stack-oldish.yaml diff --git a/random.cabal b/random.cabal index 47a87c0f..ddcef7d5 100644 --- a/random.cabal +++ b/random.cabal @@ -100,6 +100,8 @@ library deepseq >=1.1 && <2, mtl >=2.2 && <2.4, splitmix >=0.1 && <0.2 + if impl(ghc < 9.4) + build-depends: data-array-byte test-suite legacy-test type: exitcode-stdio-1.0 diff --git a/src/System/Random.hs b/src/System/Random.hs index 92f31faa..695cb63f 100644 --- a/src/System/Random.hs +++ b/src/System/Random.hs @@ -28,6 +28,10 @@ module System.Random , Uniform , UniformRange , Finite + -- * Generators for sequences of pseudo-random bytes + , uniformByteArray + , uniformByteString + , uniformFillMutableByteArray -- ** Standard pseudo-random number generator , StdGen diff --git a/src/System/Random/Internal.hs b/src/System/Random/Internal.hs index cf0af096..3172b2fa 100644 --- a/src/System/Random/Internal.hs +++ b/src/System/Random/Internal.hs @@ -55,6 +55,7 @@ module System.Random.Internal , uniformViaFiniteM , UniformRange(..) , uniformByteStringM + , uniformShortByteStringM , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M @@ -63,19 +64,24 @@ module System.Random.Internal , uniformEnumRM -- * Generators for sequences of pseudo-random bytes + , uniformByteArray + , uniformFillMutableByteArray + , uniformByteString + , genByteArrayST , genShortByteStringIO , genShortByteStringST + , defaultUnsafeUniformFillMutableByteArray ) where import Control.Arrow import Control.DeepSeq (NFData) import Control.Monad (when, (>=>)) import Control.Monad.Cont (ContT, runContT) -import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Identity (runIdentityT) import Control.Monad.ST -import Control.Monad.ST.Unsafe -import Control.Monad.State.Strict (MonadState(..), State, StateT(..), runState) -import Control.Monad.Trans (lift) +import Control.Monad.State.Strict (MonadState(..), State, StateT(..), execStateT, runState) +import Control.Monad.Trans (lift, MonadTrans) +import Data.Array.Byte (ByteArray(..), MutableByteArray(..)) import Data.Bits import Data.ByteString.Short.Internal (ShortByteString(SBS), fromShort) import Data.IORef (IORef, newIORef) @@ -86,6 +92,7 @@ import Foreign.Storable (Storable) import GHC.Exts import GHC.Generics import GHC.IO (IO(..)) +import GHC.ST (ST(..)) import GHC.Word import Numeric.Natural (Natural) import System.IO.Unsafe (unsafePerformIO) @@ -178,15 +185,31 @@ class RandomGen g where genWord64R m g = runStateGen g (unsignedBitmaskWithRejectionM uniformWord64 m) {-# INLINE genWord64R #-} - -- | @genShortByteString n g@ returns a 'ShortByteString' of length @n@ - -- filled with pseudo-random bytes. + -- | Same as `uniformByteArray`, but for `ShortByteString`. @genShortByteString n g@ returns + -- a 'ShortByteString' of length @n@ filled with pseudo-random bytes. -- -- @since 1.2.0 genShortByteString :: Int -> g -> (ShortByteString, g) genShortByteString n g = - unsafePerformIO $ runStateGenT g (genShortByteStringIO n . uniformWord64) + case uniformByteArray False n g of + (ByteArray ba#, g') -> (SBS ba#, g') {-# INLINE genShortByteString #-} + unsafeUniformFillMutableByteArray :: + MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Offset must + -- be non-negative, but this will not be checked + -> Int + -- ^ Number of randomly generated bytes to write into the array. Number of bytes + -- must be non-negative and less then the total size of the array, minus the + -- offset. This also will be checked. + -> g + -> ST s g + unsafeUniformFillMutableByteArray = defaultUnsafeUniformFillMutableByteArray + {-# INLINE unsafeUniformFillMutableByteArray #-} + -- | Yields the range of values returned by 'next'. -- -- It is required that: @@ -277,16 +300,29 @@ class Monad m => StatefulGen g m where pure (shiftL (fromIntegral h32) 32 .|. fromIntegral l32) {-# INLINE uniformWord64 #-} + -- | @uniformByteArrayM n g@ generates a 'ByteArray' of length @n@ + -- filled with pseudo-random bytes. + -- + -- @since 1.3.0 + uniformByteArrayM :: + Bool -- ^ Should `ByteArray` be allocated as pinned memory or not + -> Int -- ^ Size of the newly created `ByteArray` in number of bytes. + -> g -- ^ Generator to use for filling in the newly created `ByteArray` + -> m ByteArray + default uniformByteArrayM :: + (RandomGen f, FrozenGen f m, g ~ MutableGen f m) => Bool -> Int -> g -> m ByteArray + uniformByteArrayM isPinned n g = modifyM g (uniformByteArray isPinned n) + {-# INLINE uniformByteArrayM #-} + -- | @uniformShortByteString n g@ generates a 'ShortByteString' of length @n@ -- filled with pseudo-random bytes. -- -- @since 1.2.0 uniformShortByteString :: Int -> g -> m ShortByteString - default uniformShortByteString :: MonadIO m => Int -> g -> m ShortByteString - uniformShortByteString n = genShortByteStringIO n . uniformWord64 + uniformShortByteString = uniformShortByteStringM {-# INLINE uniformShortByteString #-} - +{-# DEPRECATED uniformShortByteString "In favor of `uniformShortByteStringM`" #-} -- | This class is designed for stateful pseudo-random number generators that -- can be saved as and restored from an immutable data type. @@ -330,57 +366,154 @@ splitFrozenM = flip modifyM split splitMutableM :: (RandomGen f, FrozenGen f m) => MutableGen f m -> m (MutableGen f m) splitMutableM = splitFrozenM >=> thawGen - -data MBA = MBA (MutableByteArray# RealWorld) - - -- | Efficiently generates a sequence of pseudo-random bytes in a platform -- independent manner. -- --- @since 1.2.0 -genShortByteStringIO :: - MonadIO m - => Int -- ^ Number of bytes to generate - -> m Word64 -- ^ IO action that can generate 8 random bytes at a time - -> m ShortByteString -genShortByteStringIO n0 gen64 = do - let !n@(I# n#) = max 0 n0 - !n64 = n `quot` 8 +-- @since 1.3.0 +uniformByteArray :: + RandomGen g + => Bool -- ^ Should byte array be allocted in pinned or unpinned memory. + -> Int -- ^ Number of bytes to generate + -> g -- ^ Pure pseudo-random numer generator + -> (ByteArray, g) +uniformByteArray isPinned n0 g = + runST $ do + let !n = max 0 n0 + mba <- + if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + g' <- unsafeUniformFillMutableByteArray mba 0 n g + ba <- freezeMutableByteArray mba + pure (ba, g') +{-# INLINE uniformByteArray #-} + +-- | Using an `ST` action that generates 8 bytes at a type fill in a new `ByteArray` in +-- architecture agnostic manner. +-- +-- @since 1.3.0 +genByteArrayST :: Bool -> Int -> ST s Word64 -> ST s ByteArray +genByteArrayST isPinned n0 action = do + let !n = max 0 n0 + mba <- if isPinned + then newPinnedMutableByteArray n + else newMutableByteArray n + runIdentityT $ defaultUnsafeUniformFillMutableByteArrayT mba 0 n (lift action) + freezeMutableByteArray mba +{-# INLINE genByteArrayST #-} + +-- | Fill in a slice of a mutable byte array with randomly generated bytes. This +-- function does not fail. +-- +-- @since 1.3.0 +uniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -- ^ Mutable array to fill with random bytes + -> Int + -- ^ Offset into a mutable array from the beginning in number of bytes. Whenever + -- negative offset is supplied it will be treaded as 0 + -> Int + -- ^ Number of randomly generated bytes to write into the array. Supplied number of + -- bytes will be clamped between 0 and the total size of the array, minus the + -- offset. + -> g + -> ST s g +uniformFillMutableByteArray mba i0 n g = do + sz <- getSizeOfMutableByteArray mba + let !offset = max (min 0 i0) sz + !numBytes = max (min 0 n) (sz - offset) + unsafeUniformFillMutableByteArray mba offset numBytes g +{-# INLINE uniformFillMutableByteArray #-} + +defaultUnsafeUniformFillMutableByteArrayT :: + (Monad (t (ST s)), MonadTrans t) + => MutableByteArray s + -> Int + -> Int + -> t (ST s) Word64 + -> t (ST s) () +defaultUnsafeUniformFillMutableByteArrayT mba i0 n gen64 = do + let !n64 = n `quot` 8 !nrem = n `rem` 8 - mba@(MBA mba#) <- - liftIO $ IO $ \s# -> - case newByteArray# n# s# of - (# s'#, mba# #) -> (# s'#, MBA mba# #) let go i = when (i < n64) $ do w64 <- gen64 -- Writing 8 bytes at a time in a Little-endian order gives us -- platform portability - liftIO $ writeWord64LE mba i w64 + lift $ writeWord64LE mba i w64 go (i + 1) - go 0 + go i0 when (nrem > 0) $ do w64 <- gen64 -- In order to not mess up the byte order we write 1 byte at a time in -- Little endian order. It is tempting to simply generate as many bytes as we -- still need using smaller generators (eg. uniformWord8), but that would -- result in inconsistent tail when total length is slightly varied. - liftIO $ writeByteSliceWord64LE mba (n - nrem) n w64 - liftIO $ IO $ \s# -> - case unsafeFreezeByteArray# mba# s# of - (# s'#, ba# #) -> (# s'#, SBS ba# #) -{-# INLINE genShortByteStringIO #-} + lift $ writeByteSliceWord64LE mba (n - nrem) n w64 +{-# INLINE defaultUnsafeUniformFillMutableByteArrayT #-} + +-- | Efficiently generates a sequence of pseudo-random bytes in a platform +-- independent manner. +-- +-- @since 1.2.0 +defaultUnsafeUniformFillMutableByteArray :: + RandomGen g + => MutableByteArray s + -> Int -- ^ Starting offset + -> Int -- ^ Number of random bytes to write into the array + -> g -- ^ ST action that can generate 8 random bytes at a time + -> ST s g +defaultUnsafeUniformFillMutableByteArray mba i0 n g = + flip execStateT g + $ defaultUnsafeUniformFillMutableByteArrayT mba i0 n (state genWord64) +{-# INLINE defaultUnsafeUniformFillMutableByteArray #-} + + +-- | Generates a pseudo-random 'ByteString' of the specified size. +-- +-- @since 1.3.0 +uniformByteString :: RandomGen g => Int -> g -> (ByteString, g) +uniformByteString n g = + case uniformByteArray True n g of + (byteArray, g') -> + (shortByteStringToByteString $ byteArrayToShortByteString byteArray, g') -- Architecture independent helpers: -io_ :: (State# RealWorld -> State# RealWorld) -> IO () -io_ m# = IO $ \s# -> (# m# s#, () #) -{-# INLINE io_ #-} -writeWord8 :: MBA -> Int -> Word8 -> IO () -writeWord8 (MBA mba#) (I# i#) (W8# w#) = io_ (writeWord8Array# mba# i# w#) +st_ :: (State# s -> State# s) -> ST s () +st_ m# = ST $ \s# -> (# m# s#, () #) +{-# INLINE st_ #-} + +ioToST :: IO a -> ST RealWorld a +ioToST (IO m#) = ST m# +{-# INLINE ioToST #-} + +newMutableByteArray :: Int -> ST s (MutableByteArray s) +newMutableByteArray (I# n#) = + ST $ \s# -> + case newByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newMutableByteArray #-} + +newPinnedMutableByteArray :: Int -> ST s (MutableByteArray s) +newPinnedMutableByteArray (I# n#) = + ST $ \s# -> + case newPinnedByteArray# n# s# of + (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #) +{-# INLINE newPinnedMutableByteArray #-} + +freezeMutableByteArray :: MutableByteArray s -> ST s ByteArray +freezeMutableByteArray (MutableByteArray mba#) = + ST $ \s# -> + case unsafeFreezeByteArray# mba# s# of + (# s'#, ba# #) -> (# s'#, ByteArray ba# #) + +writeWord8 :: MutableByteArray s -> Int -> Word8 -> ST s () +writeWord8 (MutableByteArray mba#) (I# i#) (W8# w#) = st_ (writeWord8Array# mba# i# w#) {-# INLINE writeWord8 #-} -writeByteSliceWord64LE :: MBA -> Int -> Int -> Word64 -> IO () +writeByteSliceWord64LE :: MutableByteArray s -> Int -> Int -> Word64 -> ST s () writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx where go !i !z = @@ -389,48 +522,83 @@ writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx go (i + 1) (z `shiftR` 8) {-# INLINE writeByteSliceWord64LE #-} -writeWord64LE :: MBA -> Int -> Word64 -> IO () +writeWord64LE :: MutableByteArray s -> Int -> Word64 -> ST s () #ifdef WORDS_BIGENDIAN writeWord64LE mba i w64 = do let !i8 = i * 8 writeByteSliceWord64LE mba i8 (i8 + 8) w64 #else -writeWord64LE (MBA mba#) (I# i#) w64@(W64# w64#) - | wordSizeInBits == 64 = io_ (writeWord64Array# mba# i# w64#) +writeWord64LE (MutableByteArray mba#) (I# i#) w64@(W64# w64#) + | wordSizeInBits == 64 = st_ (writeWord64Array# mba# i# w64#) | otherwise = do let !i32# = i# *# 2# !(W32# w32l#) = fromIntegral w64 !(W32# w32u#) = fromIntegral (w64 `shiftR` 32) - io_ (writeWord32Array# mba# i32# w32l#) - io_ (writeWord32Array# mba# (i32# +# 1#) w32u#) + st_ (writeWord32Array# mba# i32# w32l#) + st_ (writeWord32Array# mba# (i32# +# 1#) w32u#) #endif {-# INLINE writeWord64LE #-} +getSizeOfMutableByteArray :: MutableByteArray s -> ST s Int +getSizeOfMutableByteArray (MutableByteArray mba#) = +#if __GLASGOW_HASKELL__ >=802 + ST $ \s -> + case getSizeofMutableByteArray# mba# s of + (# s', n# #) -> (# s', I# n# #) +#else + pure $! (I# sizeofMutableByteArray# mba#) +#endif +{-# INLINE getSizeOfMutableByteArray #-} + +byteArrayToShortByteString :: ByteArray -> ShortByteString +byteArrayToShortByteString (ByteArray ba#) = SBS ba# + +-- | Convert a ShortByteString to ByteString by casting, whenever memory is pinned, +-- otherwise make a copy into a new pinned ByteString +shortByteStringToByteString :: ShortByteString -> ByteString +shortByteStringToByteString ba = +#if __GLASGOW_HASKELL__ < 802 + fromShort ba +#else + let !(SBS ba#) = ba in + if isTrue# (isByteArrayPinned# ba#) + then pinnedByteArrayToByteString ba# + else fromShort ba +{-# INLINE uniformByteStringM #-} -- | Same as 'genShortByteStringIO', but runs in 'ST'. -- -- @since 1.2.0 genShortByteStringST :: Int -> ST s Word64 -> ST s ShortByteString -genShortByteStringST n action = - unsafeIOToST (genShortByteStringIO n (unsafeSTToIO action)) +genShortByteStringST n0 action = byteArrayToShortByteString <$> genByteArrayST False n0 action {-# INLINE genShortByteStringST #-} +-- | Efficiently fills in a new `ShortByteString` in a platform independent manner. +-- +-- @since 1.2.0 +genShortByteStringIO :: + Int -- ^ Number of bytes to generate + -> IO Word64 -- ^ IO action that can generate 8 random bytes at a time + -> IO ShortByteString +genShortByteStringIO n ioAction = stToIO $ genShortByteStringST n (ioToST ioAction) +{-# INLINE genShortByteStringIO #-} + +-- | @uniformShortByteString n g@ generates a 'ShortByteString' of length @n@ +-- filled with pseudo-random bytes. +-- +-- @since 1.3.0 +uniformShortByteStringM :: StatefulGen g m => Int -> g -> m ShortByteString +uniformShortByteStringM n g = byteArrayToShortByteString <$> uniformByteArrayM False n g +{-# INLINE uniformShortByteStringM #-} -- | Generates a pseudo-random 'ByteString' of the specified size. -- -- @since 1.2.0 uniformByteStringM :: StatefulGen g m => Int -> g -> m ByteString -uniformByteStringM n g = do - ba <- uniformShortByteString n g - pure $ -#if __GLASGOW_HASKELL__ < 802 - fromShort ba -#else - let !(SBS ba#) = ba in - if isTrue# (isByteArrayPinned# ba#) - then pinnedByteArrayToByteString ba# - else fromShort ba -{-# INLINE uniformByteStringM #-} +uniformByteStringM n g = + shortByteStringToByteString . byteArrayToShortByteString + <$> uniformByteArrayM True n g + pinnedByteArrayToByteString :: ByteArray# -> ByteString pinnedByteArrayToByteString ba# = diff --git a/src/System/Random/Stateful.hs b/src/System/Random/Stateful.hs index 70113f5a..8688819f 100644 --- a/src/System/Random/Stateful.hs +++ b/src/System/Random/Stateful.hs @@ -28,7 +28,15 @@ module System.Random.Stateful -- * Mutable pseudo-random number generator interfaces -- $interfaces - , StatefulGen(..) + , StatefulGen + ( uniformWord32R + , uniformWord64R + , uniformWord8 + , uniformWord16 + , uniformWord32 + , uniformWord64 + , uniformShortByteString + ) , FrozenGen(..) , withMutableGen , withMutableGen_ @@ -44,7 +52,7 @@ module System.Random.Stateful -- * Monadic adapters for pure pseudo-random number generators #monadicadapters# -- $monadicadapters - -- ** Pure adapter + -- ** Pure adapter in 'StateT' , StateGen(..) , StateGenM(..) , runStateGen @@ -85,14 +93,23 @@ module System.Random.Stateful , uniformViaFiniteM , UniformRange(..) - -- * Generators for sequences of pseudo-random bytes + -- ** Generators for sequences of pseudo-random bytes + , uniformByteArrayM + , uniformByteStringM + , uniformShortByteStringM + + -- * Helper functions for createing instances + -- ** Sequences of bytes + , genByteArrayST , genShortByteStringIO , genShortByteStringST - , uniformByteStringM + , defaultUnsafeUniformFillMutableByteArray + -- ** Floating point numbers , uniformDouble01M , uniformDoublePositive01M , uniformFloat01M , uniformFloatPositive01M + -- ** Enum types , uniformEnumM , uniformEnumRM diff --git a/stack-oldish.yaml b/stack-oldish.yaml new file mode 100644 index 00000000..06ec8886 --- /dev/null +++ b/stack-oldish.yaml @@ -0,0 +1,5 @@ +resolver: lts-18.28 +packages: +- . +extra-deps: +- data-array-byte-0.1.0.1@sha256:2ef1bd3511e82ba56f7f23cd793dd2da84338a1e7c2cbea5b151417afe3baada,1989 diff --git a/stack.yaml b/stack.yaml index d15416fa..7f9dc000 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-18.28 +resolver: lts-21.17 packages: - . extra-deps: [] diff --git a/test/Spec.hs b/test/Spec.hs index 8868a6c4..a4281e45 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -125,7 +125,7 @@ byteStringSpec = let g = mkStdGen 2021 bs = [78,232,117,189,13,237,63,84,228,82,19,36,191,5,128,192] :: [Word8] forM_ [0 .. length bs - 1] $ \ n -> do - xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteString n) + xs <- SBS.unpack <$> runStateGenT_ g (uniformShortByteStringM n) xs @?= take n bs ys <- BS.unpack <$> runStateGenT_ g (uniformByteStringM n) ys @?= xs diff --git a/test/Spec/Stateful.hs b/test/Spec/Stateful.hs index 8c951d43..176cebb0 100644 --- a/test/Spec/Stateful.hs +++ b/test/Spec/Stateful.hs @@ -95,7 +95,7 @@ statefulSpecFor toIO toStdGen = , testProperty "uniformShortByteString/genShortByteString" $ forAll $ \(n', f :: f) -> let n = abs n' `mod` 1000 -- Ensure it is not too big - in matchRandomGenSpec toIO (uniformShortByteString n) (genShortByteString n) toStdGen f + in matchRandomGenSpec toIO (uniformShortByteStringM n) (genShortByteString n) toStdGen f ] ]