Skip to content

Commit

Permalink
ENH: Add version to state
Browse files Browse the repository at this point in the history
Add version to state as a first step to a versioned RandomState
  • Loading branch information
bashtage committed May 24, 2016
1 parent e83c366 commit 6697448
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
14 changes: 12 additions & 2 deletions randomstate/randomstate.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ cdef class RandomState:
__MAXSIZE = <uint64_t>sys.maxsize
cdef object __seed
cdef object __stream
cdef object __version

IF RS_RNG_SEED==1:
def __init__(self, seed=None):
Expand All @@ -176,6 +177,8 @@ cdef class RandomState:
IF RS_RNG_MOD_NAME == 'dsfmt':
self.rng_state.buffered_uniforms = <double *>PyArray_malloc_aligned(2 * DSFMT_N * sizeof(double))
self.lock = Lock()
self.__version = 0

self.__seed = seed
self.__stream = None

Expand All @@ -186,6 +189,8 @@ cdef class RandomState:
self.rng_state.rng = <rng_t *>PyArray_malloc_aligned(sizeof(rng_t))
self.rng_state.binomial = &self.binomial_info
self.lock = Lock()
self.__version = 0

self.__seed = seed
self.__stream = stream

Expand Down Expand Up @@ -444,7 +449,8 @@ cdef class RandomState:
'state': _get_state(self.rng_state),
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
'seed': self.__seed}
'seed': self.__seed,
'version': self.__version}
if self.__stream is not None:
state['stream'] = self.__stream
return state
Expand Down Expand Up @@ -485,7 +491,8 @@ cdef class RandomState:
'state': _get_state(self.rng_state),
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
'seed': self.__seed}
'seed': self.__seed,
'version': self.__version}
if self.__stream is not None:
state['stream'] = self.__stream
return state
Expand Down Expand Up @@ -545,6 +552,9 @@ cdef class RandomState:

if state['name'] != rng_name:
raise ValueError('Not a ' + rng_name + ' RNG state')
if 'version' in state:
if state['version'] != 0:
raise NotImplementedError('Support for multiple version has not been implemented.')

_set_state(&self.rng_state, state['state'])
self.rng_state.has_gauss = state['gauss']['has_gauss']
Expand Down
5 changes: 5 additions & 0 deletions randomstate/tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ def test_pickle(self):
print(unpick.get_state())
assert_(comp_state(self.rs.get_state(), unpick.get_state()))

def test_version(self):
state = self.rs.get_state()
assert_('version' in state)
assert_(state['version'] == 0)

def test_seed_array(self):
if self.seed_vector_bits is None:
raise SkipTest
Expand Down

0 comments on commit 6697448

Please sign in to comment.