Skip to content

Commit

Permalink
Fix dump_chk in UCASSCF (fix pyscf#2432)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Nov 4, 2024
1 parent e66bb81 commit 2386fbf
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
5 changes: 2 additions & 3 deletions pyscf/mcscf/chkfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def dump_mcscf(
mo_coeff = mc.mo_coeff
if mo_occ is None:
mo_occ = mc.mo_occ
if mo_energy is None:
mo_energy = mc.mo_energy
# if ci_vector is None: ci_vector = mc.ci

if h5py.is_hdf5(chkfile):
Expand Down Expand Up @@ -93,7 +91,8 @@ def store(subkey, val):
store("ncore", ncore)
store("ncas", ncas)
store("mo_occ", mo_occ)
store("mo_energy", mo_energy)
if mo_energy is not None:
store("mo_energy", mo_energy)
store("casdm1", casdm1)

if not mixed_ci:
Expand Down
8 changes: 8 additions & 0 deletions pyscf/mcscf/test/test_umc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import unittest
import tempfile
import numpy
from pyscf import lib
from pyscf import gto
Expand Down Expand Up @@ -46,6 +47,13 @@ def tearDownModule():


class KnownValues(unittest.TestCase):
def test_ucasscf(self):
with tempfile.NamedTemporaryFile() as f:
mc = mcscf.UCASSCF(m, 4, 4)
mc.chkfile = f.name
mc.run()
self.assertAlmostEqual(mc.e_tot, -75.7460662487894, 6)

def test_with_x2c_scanner(self):
mc1 = mcscf.UCASSCF(m, 4, 4).x2c().run()
self.assertAlmostEqual(mc1.e_tot, -75.795316854668201, 6)
Expand Down
8 changes: 5 additions & 3 deletions pyscf/mcscf/umc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None,
and (norm_gorb0 < conv_tol_grad and norm_ddm < conv_tol_ddm)):
conv = True

if dump_chk:
if dump_chk and casscf.chkfile:
casscf.dump_chk(locals())

if callable(callback):
Expand Down Expand Up @@ -383,8 +383,7 @@ class UCASSCF(ucasci.UCASBase):
def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
ucasci.UCASBase.__init__(self, mf_or_mol, ncas, nelecas, ncore)
self.frozen = frozen

self.chkfile = self._scf.chkfile
self.chkfile = None

self.fcisolver.max_cycle = getattr(__config__,
'mcscf_umc1step_UCASSCF_fcisolver_max_cycle', 50)
Expand Down Expand Up @@ -764,6 +763,9 @@ def dump_chk(self, envs_or_file):
if envs is not None:
if self.chk_ci:
civec = envs['fcivec']
e_tot = envs['e_tot']
e_cas = envs['e_cas']
casdm1 = envs['casdm1']
if 'mo' in envs:
mo_coeff = envs['mo']
else:
Expand Down

0 comments on commit 2386fbf

Please sign in to comment.