diff --git a/pyscf/mcscf/chkfile.py b/pyscf/mcscf/chkfile.py index 826f10ddbe..73fa255808 100644 --- a/pyscf/mcscf/chkfile.py +++ b/pyscf/mcscf/chkfile.py @@ -59,8 +59,6 @@ def dump_mcscf( mo_coeff = mc.mo_coeff if mo_occ is None: mo_occ = mc.mo_occ - if mo_energy is None: - mo_energy = mc.mo_energy # if ci_vector is None: ci_vector = mc.ci if h5py.is_hdf5(chkfile): @@ -93,7 +91,8 @@ def store(subkey, val): store("ncore", ncore) store("ncas", ncas) store("mo_occ", mo_occ) - store("mo_energy", mo_energy) + if mo_energy is not None: + store("mo_energy", mo_energy) store("casdm1", casdm1) if not mixed_ci: diff --git a/pyscf/mcscf/test/test_umc1step.py b/pyscf/mcscf/test/test_umc1step.py index ab47b23e0d..afbca92483 100644 --- a/pyscf/mcscf/test/test_umc1step.py +++ b/pyscf/mcscf/test/test_umc1step.py @@ -14,6 +14,7 @@ # limitations under the License. import unittest +import tempfile import numpy from pyscf import lib from pyscf import gto @@ -46,6 +47,13 @@ def tearDownModule(): class KnownValues(unittest.TestCase): + def test_ucasscf(self): + with tempfile.NamedTemporaryFile() as f: + mc = mcscf.UCASSCF(m, 4, 4) + mc.chkfile = f.name + mc.run() + self.assertAlmostEqual(mc.e_tot, -75.7460662487894, 6) + def test_with_x2c_scanner(self): mc1 = mcscf.UCASSCF(m, 4, 4).x2c().run() self.assertAlmostEqual(mc1.e_tot, -75.795316854668201, 6) diff --git a/pyscf/mcscf/umc1step.py b/pyscf/mcscf/umc1step.py index 777799028b..00b908166f 100644 --- a/pyscf/mcscf/umc1step.py +++ b/pyscf/mcscf/umc1step.py @@ -327,7 +327,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None, and (norm_gorb0 < conv_tol_grad and norm_ddm < conv_tol_ddm)): conv = True - if dump_chk: + if dump_chk and casscf.chkfile: casscf.dump_chk(locals()) if callable(callback): @@ -383,8 +383,7 @@ class UCASSCF(ucasci.UCASBase): def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None): ucasci.UCASBase.__init__(self, mf_or_mol, ncas, nelecas, ncore) self.frozen = frozen - - self.chkfile = self._scf.chkfile + self.chkfile = None self.fcisolver.max_cycle = getattr(__config__, 'mcscf_umc1step_UCASSCF_fcisolver_max_cycle', 50) @@ -764,6 +763,9 @@ def dump_chk(self, envs_or_file): if envs is not None: if self.chk_ci: civec = envs['fcivec'] + e_tot = envs['e_tot'] + e_cas = envs['e_cas'] + casdm1 = envs['casdm1'] if 'mo' in envs: mo_coeff = envs['mo'] else: