diff --git a/smol/moca/sampler/sampler.py b/smol/moca/sampler/sampler.py index 18725f2b1..00d982a23 100644 --- a/smol/moca/sampler/sampler.py +++ b/smol/moca/sampler/sampler.py @@ -95,7 +95,9 @@ def from_ensemble( if kernel_type is None: kernel_type = "Metropolis" - mckernel = mckernel_factory(kernel_type, ensemble, step_type, *args, **kwargs) + mckernel = mckernel_factory( + kernel_type, ensemble, step_type, seed=seed, *args, **kwargs + ) # get a trial trace to initialize sample container trace _trace = mckernel.compute_initial_trace(np.zeros(ensemble.num_sites, dtype=int)) sample_trace = Trace( diff --git a/tests/test_moca/test_sampler.py b/tests/test_moca/test_sampler.py index 5549e57e9..455e6ac3d 100644 --- a/tests/test_moca/test_sampler.py +++ b/tests/test_moca/test_sampler.py @@ -11,9 +11,9 @@ @pytest.fixture(params=[1, 5]) -def sampler(ensemble, request): +def sampler(ensemble, rng, request): sampler = Sampler.from_ensemble( - ensemble, temperature=TEMPERATURE, nwalkers=request.param + ensemble, temperature=TEMPERATURE, seed=rng, nwalkers=request.param ) # fix this additional attribute to sampler to access in gen occus for tests sampler.num_sites = ensemble.num_sites