Skip to content

Commit

Permalink
fix passing seed to Sampler.from_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed May 25, 2022
1 parent 44d719b commit ff5dec4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion smol/moca/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def from_ensemble(
if kernel_type is None:
kernel_type = "Metropolis"

mckernel = mckernel_factory(kernel_type, ensemble, step_type, *args, **kwargs)
mckernel = mckernel_factory(
kernel_type, ensemble, step_type, seed=seed, *args, **kwargs
)
# get a trial trace to initialize sample container trace
_trace = mckernel.compute_initial_trace(np.zeros(ensemble.num_sites, dtype=int))
sample_trace = Trace(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_moca/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


@pytest.fixture(params=[1, 5])
def sampler(ensemble, request):
def sampler(ensemble, rng, request):
sampler = Sampler.from_ensemble(
ensemble, temperature=TEMPERATURE, nwalkers=request.param
ensemble, temperature=TEMPERATURE, seed=rng, nwalkers=request.param
)
# fix this additional attribute to sampler to access in gen occus for tests
sampler.num_sites = ensemble.num_sites
Expand Down

0 comments on commit ff5dec4

Please sign in to comment.