diff --git a/randomstate/randomstate.pyx b/randomstate/randomstate.pyx index 6d1100f0..5b83ee2a 100644 --- a/randomstate/randomstate.pyx +++ b/randomstate/randomstate.pyx @@ -168,6 +168,7 @@ cdef class RandomState: __MAXSIZE = sys.maxsize cdef object __seed cdef object __stream + cdef object __version IF RS_RNG_SEED==1: def __init__(self, seed=None): @@ -176,6 +177,8 @@ cdef class RandomState: IF RS_RNG_MOD_NAME == 'dsfmt': self.rng_state.buffered_uniforms = PyArray_malloc_aligned(2 * DSFMT_N * sizeof(double)) self.lock = Lock() + self.__version = 0 + self.__seed = seed self.__stream = None @@ -186,6 +189,8 @@ cdef class RandomState: self.rng_state.rng = PyArray_malloc_aligned(sizeof(rng_t)) self.rng_state.binomial = &self.binomial_info self.lock = Lock() + self.__version = 0 + self.__seed = seed self.__stream = stream @@ -444,7 +449,8 @@ cdef class RandomState: 'state': _get_state(self.rng_state), 'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss}, 'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger}, - 'seed': self.__seed} + 'seed': self.__seed, + 'version': self.__version} if self.__stream is not None: state['stream'] = self.__stream return state @@ -485,7 +491,8 @@ cdef class RandomState: 'state': _get_state(self.rng_state), 'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss}, 'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger}, - 'seed': self.__seed} + 'seed': self.__seed, + 'version': self.__version} if self.__stream is not None: state['stream'] = self.__stream return state @@ -545,6 +552,9 @@ cdef class RandomState: if state['name'] != rng_name: raise ValueError('Not a ' + rng_name + ' RNG state') + if 'version' in state: + if state['version'] != 0: + raise NotImplementedError('Support for multiple version has not been implemented.') _set_state(&self.rng_state, state['state']) self.rng_state.has_gauss = state['gauss']['has_gauss'] diff --git a/randomstate/tests/test_smoke.py b/randomstate/tests/test_smoke.py index e40c21d4..7c1571b5 100644 --- a/randomstate/tests/test_smoke.py +++ b/randomstate/tests/test_smoke.py @@ -445,6 +445,11 @@ def test_pickle(self): print(unpick.get_state()) assert_(comp_state(self.rs.get_state(), unpick.get_state())) + def test_version(self): + state = self.rs.get_state() + assert_('version' in state) + assert_(state['version'] == 0) + def test_seed_array(self): if self.seed_vector_bits is None: raise SkipTest