diff --git a/marcia/sampler.py b/marcia/sampler.py index 0e96442..14047be 100644 --- a/marcia/sampler.py +++ b/marcia/sampler.py @@ -72,30 +72,26 @@ def sampler(self,reset=False): print(f'Converged at iteration {sampler.iteration}') break old_tau = tau - return sampler else: sampler.run_mcmc(self.sampler_pos(), self.max_n, progress=True) - return sampler else: if reset: print(f'Reseting sampling from iteration: {last_iteration}') self.HDFBackend.reset(self.nwalkers, self.ndim) return self.sampler() print(f'Already completed {last_iteration} iterations') - return self.HDFBackend def get_burnin(self): - if self.converge: - tau = self.sampler().get_autocorr_time() - burnin = int(2 * np.max(tau)) - thin = int(0.5 * np.min(tau)) - else: - burnin = 50 - thin = 1 + tau = self.HDFBackend.get_autocorr_time() + burnin = int(2 * np.max(tau)) + thin = int(0.5 * np.min(tau)) return burnin, thin - def get_chain(self, getdist=False,reset=False): - sampler = self.sampler(reset=reset) + def get_chain(self, getdist=False): + sampler = self.HDFBackend + if sampler.iteration < self.max_n: + print(f'Only {sampler.iteration} iterations completed') + print(f'You should run the sampler to finsih the sampling of {self.max_n} iterations') burnin, thin = self.get_burnin() samples = sampler.get_chain(discard=burnin, thin=thin, flat=True) if getdist: