Skip to content

Commit

Permalink
Merge pull request #49 from idontgetoutmuch/improve-coverage
Browse files Browse the repository at this point in the history
UniformRange for Float and Double
  • Loading branch information
idontgetoutmuch authored and curiousleo committed May 19, 2020
2 parents a3b09d2 + 82ade70 commit 26446a1
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 40 deletions.
105 changes: 68 additions & 37 deletions System/Random.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GHCForeignImportPrim #-}
{-# LANGUAGE UnliftedFFITypes #-}
#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif
Expand Down Expand Up @@ -82,27 +84,20 @@
-- (Word16, Word16)@ is a function to pull apart a 'Word32' into a
-- pair of 'Word16'):
--
-- >>> data PCGen' = PCGen' !Word64 !Word64
-- >>> newtype PCGen' = PCGen' { unPCGen :: PCGen }
--
-- >>> let stepGen' = second PCGen' . stepGen . unPCGen
--
-- >>> :{
-- instance RandomGen PCGen' where
-- genWord8 (PCGen' s i) = (z, PCGen' s' i')
-- where
-- (x, PCGen s' i') = stepGen (PCGen s i)
-- y = fst $ unBuildWord32 x
-- z = fst $ unBuildWord16 y
-- genWord16 (PCGen' s i) = (y, PCGen' s' i')
-- where
-- (x, PCGen s' i') = stepGen (PCGen s i)
-- y = fst $ unBuildWord32 x
-- genWord32 (PCGen' s i) = (x, PCGen' s' i')
-- where
-- (x, PCGen s' i') = stepGen (PCGen s i)
-- genWord64 (PCGen' s i) = (undefined, PCGen' s i)
-- where
-- (x, g) = stepGen (PCGen s i)
-- (y, PCGen s' i') = stepGen g
-- split _ = error "This PRNG is not splittable"
-- genWord8 = first fromIntegral . stepGen'
-- genWord16 = first fromIntegral . stepGen'
-- genWord32 = stepGen'
-- genWord64 g = (buildWord64 x y, g'')
-- where
-- (x, g') = stepGen' g
-- (y, g'') = stepGen' g'
-- buildWord64 w0 w1 = ((fromIntegral w1) `shiftL` 32) .|. (fromIntegral w0)
-- :}
--
-- [/Example for RNG Users:/]
Expand Down Expand Up @@ -219,10 +214,13 @@ import Foreign.C.Types
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (plusPtr)
import Foreign.Storable (peekByteOff, pokeByteOff)
import GHC.Exts (Ptr(..), build)
import GHC.Exts (Ptr(..))
import GHC.ForeignPtr
import System.IO.Unsafe (unsafePerformIO)
import qualified System.Random.SplitMix as SM
import GHC.Base
import GHC.Word


#if !MIN_VERSION_primitive(0,7,0)
import Data.Primitive.Types (Addr(..))
Expand All @@ -237,25 +235,13 @@ mutableByteArrayContentsCompat :: MutableByteArray s -> Ptr Word8
{-# INLINE mutableByteArrayContentsCompat #-}

-- $setup
-- >>> import Control.Arrow (first, second)
-- >>> import Control.Monad (replicateM)
-- >>> import Data.Bits
-- >>> import Data.Word
-- >>> import System.IO (IOMode(WriteMode), hPutStr, withBinaryFile)
-- >>> :set -XFlexibleContexts
-- >>> :set -fno-warn-missing-methods
-- >>> :{
-- unBuildWord32 :: Word32 -> (Word16, Word16)
-- unBuildWord32 w = (fromIntegral (shiftR w 16),
-- fromIntegral (fromIntegral (maxBound :: Word16) .&. w))
-- :}
--
-- >>> :{
-- unBuildWord16 :: Word16 -> (Word8, Word8)
-- unBuildWord16 w = (fromIntegral (shiftR w 8),
-- fromIntegral (fromIntegral (maxBound :: Word8) .&. w))
-- :}
--
-- >>> :{
-- buildWord64 :: Word32 -> Word32 -> Word64
-- buildWord64 w0 w1 = ((fromIntegral w1) `shiftL` 32) .|. (fromIntegral w0)
-- :}

-- | The class 'RandomGen' provides a common interface to random number
-- generators.
Expand Down Expand Up @@ -1003,7 +989,40 @@ instance Random Double where
random = randomDouble
randomM = uniformR (0, 1)

instance UniformRange Double
instance UniformRange Double where
uniformR (l, h) g = do
w64 <- uniformWord64 g
let x = word64ToDoubleInUnitInterval w64
return $ (h - l) * x + l

-- | Turns a given uniformly distributed 'Word64' value into a uniformly
-- distributed 'Double' value in the range [0, 1).
word64ToDoubleInUnitInterval :: Word64 -> Double
word64ToDoubleInUnitInterval w64 = between1and2 - 1.0
where
between1and2 = castWord64ToDouble $ (w64 `unsafeShiftR` 12) .|. 0x3ff0000000000000
{-# INLINE word64ToDoubleInUnitInterval #-}

-- | These are now in 'GHC.Float' but unpatched in some versions so
-- for now we roll our own. See
-- https://gitlab.haskell.org/ghc/ghc/-/blob/6d172e63f3dd3590b0a57371efb8f924f1fcdf05/libraries/base/GHC/Float.hs
{-# INLINE castWord32ToFloat #-}
castWord32ToFloat :: Word32 -> Float
castWord32ToFloat (W32# w#) = F# (stgWord32ToFloat w#)

foreign import prim "stg_word32ToFloatyg"
stgWord32ToFloat :: Word# -> Float#

{-# INLINE castWord64ToDouble #-}
castWord64ToDouble :: Word64 -> Double
castWord64ToDouble (W64# w) = D# (stgWord64ToDouble w)

foreign import prim "stg_word64ToDoubleyg"
#if WORD_SIZE_IN_BITS == 64
stgWord64ToDouble :: Word# -> Double#
#else
stgWord64ToDouble :: Word64# -> Double#
#endif

randomDouble :: RandomGen b => b -> (Double, b)
randomDouble rng =
Expand All @@ -1022,7 +1041,19 @@ instance Random Float where
random = randomFloat
randomM = uniformR (0, 1)

instance UniformRange Float
instance UniformRange Float where
uniformR (l, h) g = do
w32 <- uniformWord32 g
let x = word32ToFloatInUnitInterval w32
return $ (h - l) * x + l

-- | Turns a given uniformly distributed 'Word32' value into a uniformly
-- distributed 'Float' value in the range [0,1).
word32ToFloatInUnitInterval :: Word32 -> Float
word32ToFloatInUnitInterval w32 = between1and2 - 1.0
where
between1and2 = castWord32ToFloat $ (w32 `unsafeShiftR` 9) .|. 0x3f800000
{-# INLINE word32ToFloatInUnitInterval #-}

randomFloat :: RandomGen b => b -> (Float, b)
randomFloat rng =
Expand Down
46 changes: 46 additions & 0 deletions cbits/CastFloatWord.cmm
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/* From: https://gitlab.haskell.org/ghc/ghc/-/blob/6d172e63f3dd3590b0a57371efb8f924f1fcdf05/libraries/base/cbits/CastFloatWord.cmm */
#include "Cmm.h"
#include "MachDeps.h"

#if WORD_SIZE_IN_BITS == 64
#define DOUBLE_SIZE_WDS 1
#else
#define DOUBLE_SIZE_WDS 2
#endif

#if SIZEOF_W == 4
#define TO_ZXW_(x) %zx32(x)
#elif SIZEOF_W == 8
#define TO_ZXW_(x) %zx64(x)
#endif

stg_word64ToDoubleyg(I64 w)
{
D_ d;
P_ ptr;

STK_CHK_GEN_N (DOUBLE_SIZE_WDS);

reserve DOUBLE_SIZE_WDS = ptr {
I64[ptr] = w;
d = D_[ptr];
}

return (d);
}

stg_word32ToFloatyg(W_ w)
{
F_ f;
P_ ptr;

STK_CHK_GEN_N (1);

reserve 1 = ptr {
I32[ptr] = %lobits32(w);
f = F_[ptr];
}

return (f);
}

1 change: 1 addition & 0 deletions random.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ library
primitive >= 0.6.4.0 && <8,
mtl -any,
splitmix -any
c-sources: cbits/CastFloatWord.cmm

test-suite legacy
type: exitcode-stdio-1.0
Expand Down
12 changes: 11 additions & 1 deletion tests/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import qualified Spec.Bitmask as Range
main :: IO ()
main = defaultMain $ testGroup "Spec"
[ bitmaskSpecWord32, bitmaskSpecWord64
, rangeSpecWord32, rangeSpecInt
, rangeSpecWord32, rangeSpecDouble, rangeSpecFloat, rangeSpecInt
]

bitmaskSpecWord32 :: TestTree
Expand All @@ -38,6 +38,16 @@ rangeSpecWord32 = testGroup "uniformR (Word32)"
, SC.testProperty "(Word32) singleton" $ seeded $ Range.singleton @StdGen @Word32
]

rangeSpecDouble :: TestTree
rangeSpecDouble = testGroup "uniformR (Double)"
[ SC.testProperty "(Double) uniform bounded" $ seeded $ Range.uniformBounded @StdGen @Double
]

rangeSpecFloat :: TestTree
rangeSpecFloat = testGroup "uniformR (Float)"
[ SC.testProperty "(Float) uniform bounded" $ seeded $ Range.uniformBounded @StdGen @Float
]

rangeSpecInt :: TestTree
rangeSpecInt = testGroup "uniformR (Int)"
[ SC.testProperty "(Int) symmetric" $ seeded $ Range.symmetric @StdGen @Int
Expand Down
5 changes: 4 additions & 1 deletion tests/Spec/Bitmask.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module Spec.Bitmask (symmetric, bounded, singleton) where
module Spec.Bitmask (symmetric, bounded, singleton, uniformBounded) where

import Data.Bits
import System.Random
Expand All @@ -17,3 +17,6 @@ singleton :: (RandomGen g, FiniteBits a, Num a, Ord a, Random a) => g -> a -> Bo
singleton g x = result == x
where
result = fst (bitmaskWithRejection (x, x) g)

uniformBounded :: (RandomGen g, UniformRange a, Ord a) => g -> (a, a) -> Bool
uniformBounded g (l, r) = runGenState_ g (\g -> (uniformR (l, r) g >>= \result -> return ((min l r) <= result && result <= (max l r))))
4 changes: 3 additions & 1 deletion tests/doctests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ main = do
traverse_ putStrLn args
doctest args
where
args = flags ++ pkgs ++ module_sources
-- '-fobject-code' is required to get the doctests to build without
-- tripping over the Cmm bits.
args = ["-fobject-code"] ++ flags ++ pkgs ++ module_sources

0 comments on commit 26446a1

Please sign in to comment.