Skip to content

Commit

Permalink
Update mf type check code in all post HF functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Oct 26, 2023
1 parent 2fc341a commit 7860b32
Show file tree
Hide file tree
Showing 27 changed files with 160 additions and 96 deletions.
14 changes: 9 additions & 5 deletions pyscf/adc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def ADC(mf, frozen=None, mo_coeff=None, mo_occ=None):
#elif isinstance(mf, scf.rohf.ROHF):
# lib.logger.warn(mf, 'RADC method does not support ROHF reference. ROHF object '
# 'is converted to UHF object and UADC method is called.')
# mf = scf.addons.convert_to_uhf(mf)
# mf = mf.to_uhf(mf)
# return UADC(mf, frozen, mo_coeff, mo_occ)
# TODO add ROHF functionality
elif isinstance(mf, scf.rhf.RHF):
Expand All @@ -64,8 +64,10 @@ def UADC(mf, frozen=None, mo_coeff=None, mo_occ=None):

from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.uhf.UHF):
mf = scf.addons.convert_to_uhf(mf)
if not isinstance(mf, scf.uhf.UHF):
mf = mf.to_uhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

return uadc.UADC(mf, frozen, mo_coeff, mo_occ)

Expand All @@ -79,8 +81,10 @@ def RADC(mf, frozen=None, mo_coeff=None, mo_occ=None):

from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.rhf.RHF):
mf = scf.addons.convert_to_rhf(mf)
if not isinstance(mf, scf.rhf.RHF):
mf = mf.to_rhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

return radc.RADC(mf, frozen, mo_coeff, mo_occ)

Expand Down
2 changes: 1 addition & 1 deletion pyscf/agf2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def AGF2(mf, nmom=(None,0), frozen=None, mo_energy=None, mo_coeff=None, mo_occ=N
elif isinstance(mf, scf.rohf.ROHF):
lib.logger.warn(mf, 'RAGF2 method does not support ROHF reference. '
'Converting to UHF and using UAGF2.')
mf = scf.addons.convert_to_uhf(mf)
mf = mf.to_uhf()
return UAGF2(mf, nmom, frozen, mo_energy, mo_coeff, mo_occ)

elif isinstance(mf, scf.rhf.RHF):
Expand Down
52 changes: 31 additions & 21 deletions pyscf/cc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@
from pyscf import scf

def CCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):
if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
return UCCSD(mf, frozen, mo_coeff, mo_occ)
elif isinstance(mf, scf.ghf.GHF):
elif mf.istype('GHF'):
return GCCSD(mf, frozen, mo_coeff, mo_occ)
else:
return RCCSD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -93,21 +93,24 @@ def CCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):
def RCCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):
import numpy
from pyscf import lib
from pyscf.df.df_jk import _DFHF
from pyscf.soscf import newton_ah
from pyscf.cc import dfccsd

if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
raise RuntimeError('RCCSD cannot be used with UHF method.')
elif isinstance(mf, scf.rohf.ROHF):
elif mf.istype('ROHF'):
lib.logger.warn(mf, 'RCCSD method does not support ROHF method. ROHF object '
'is converted to UHF object and UCCSD method is called.')
mf = scf.addons.convert_to_uhf(mf)
mf = mf.to_uhf()
return UCCSD(mf, frozen, mo_coeff, mo_occ)

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.hf.RHF):
mf = scf.addons.convert_to_rhf(mf)
if not mf.istype('RHF'):
mf = mf.to_rhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
return dfccsd.RCCSD(mf, frozen, mo_coeff, mo_occ)

elif numpy.iscomplexobj(mo_coeff) or numpy.iscomplexobj(mf.mo_coeff):
Expand All @@ -119,12 +122,15 @@ def RCCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):


def UCCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.df.df_jk import _DFHF
from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.uhf.UHF):
mf = scf.addons.convert_to_uhf(mf)
if not mf.istype('UHF'):
mf = mf.to_uhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
# TODO: DF-UCCSD with memory-efficient particle-particle ladder,
# similar to dfccsd.RCCSD
return uccsd.UCCSD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -134,22 +140,25 @@ def UCCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):


def GCCSD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.df.df_jk import _DFHF
from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.ghf.GHF):
mf = scf.addons.convert_to_ghf(mf)
if not mf.istype('GHF'):
mf = mf.to_ghf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
raise NotImplementedError('DF-GCCSD')
else:
return gccsd.GCCSD(mf, frozen, mo_coeff, mo_occ)
GCCSD.__doc__ = gccsd.GCCSD.__doc__


def QCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
raise NotImplementedError
elif isinstance(mf, scf.ghf.GHF):
elif mf.istype('GHF'):
raise NotImplementedError
else:
return RQCISD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -162,16 +171,17 @@ def RQCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf import lib
from pyscf.soscf import newton_ah

if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
raise RuntimeError('RQCISD cannot be used with UHF method.')
elif isinstance(mf, scf.rohf.ROHF):
elif mf.istype('ROHF'):
lib.logger.warn(mf, 'RQCISD method does not support ROHF method. ROHF object '
'is converted to UHF object and UQCISD method is called.')
mf = scf.addons.convert_to_uhf(mf)
raise NotImplementedError

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.hf.RHF):
mf = scf.addons.convert_to_rhf(mf)
if not mf.istype('RHF'):
mf = mf.to_rhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

elif numpy.iscomplexobj(mo_coeff) or numpy.iscomplexobj(mf.mo_coeff):
raise NotImplementedError
Expand Down
6 changes: 2 additions & 4 deletions pyscf/cc/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,13 @@ def spin2spatial(tx, orbspin):
return t2aa,t2ab,t2bb

def convert_to_uccsd(mycc):
from pyscf import scf
from pyscf.cc import uccsd, gccsd
if isinstance(mycc, uccsd.UCCSD):
return mycc
elif isinstance(mycc, gccsd.GCCSD):
raise NotImplementedError

mf = scf.addons.convert_to_uhf(mycc._scf)
mf = mycc._scf.to_uhf()
ucc = uccsd.UCCSD(mf)
assert (mycc._nocc is None)
assert (mycc._nmo is None)
Expand All @@ -143,12 +142,11 @@ def convert_to_uccsd(mycc):
return ucc

def convert_to_gccsd(mycc):
from pyscf import scf
from pyscf.cc import gccsd
if isinstance(mycc, gccsd.GCCSD):
return mycc

mf = scf.addons.convert_to_ghf(mycc._scf)
mf = mycc._scf.to_ghf()
gcc = gccsd.GCCSD(mf)
assert (mycc._nocc is None)
assert (mycc._nmo is None)
Expand Down
2 changes: 2 additions & 0 deletions pyscf/cc/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,8 @@ def _make_eris_incore(mycc, mo_coeff=None):
return eris

def _make_eris_outcore(mycc, mo_coeff=None):
from pyscf.scf.hf import RHF
assert isinstance(mycc._scf, RHF)
cput0 = (logger.process_clock(), logger.perf_counter())
log = logger.Logger(mycc.stdout, mycc.verbose)
eris = _ChemistsERIs()
Expand Down
1 change: 1 addition & 0 deletions pyscf/cc/dfccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _contract_vvvv_t2(self, mycc, t2, direct=False, out=None, verbose=None):
return _contract_vvvv_t2(mycc, self.mol, self.vvL, t2, out, verbose)

def _make_df_eris(cc, mo_coeff=None):
assert cc._scf.istype('RHF')
eris = _ChemistsERIs()
eris._common_init_(cc, mo_coeff)
nocc = eris.nocc
Expand Down
3 changes: 2 additions & 1 deletion pyscf/cc/gccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class GCCSD(ccsd.CCSDBase):
conv_tol_normt = getattr(__config__, 'cc_gccsd_GCCSD_conv_tol_normt', 1e-6)

def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):
assert (isinstance(mf, scf.ghf.GHF))
ccsd.CCSDBase.__init__(self, mf, frozen, mo_coeff, mo_occ)

def init_amps(self, eris=None):
Expand Down Expand Up @@ -385,6 +384,8 @@ def _make_eris_incore(mycc, mo_coeff=None, ao2mofn=None):
return eris

def _make_eris_outcore(mycc, mo_coeff=None):
from pyscf.scf.ghf import GHF
assert isinstance(mycc._scf, GHF)
cput0 = (logger.process_clock(), logger.perf_counter())
log = logger.Logger(mycc.stdout, mycc.verbose)

Expand Down
4 changes: 2 additions & 2 deletions pyscf/cc/gccsd_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def get_wv(a, b, c):
mycc = cc.CCSD(mf).set(conv_tol=1e-11).run()
et = mycc.ccsd_t()

mycc = cc.GCCSD(scf.addons.convert_to_ghf(mf)).set(conv_tol=1e-11).run()
mycc = cc.GCCSD(mf.to_ghf()).set(conv_tol=1e-11).run()
eris = mycc.ao2mo()
print(kernel(mycc, eris) - et)

numpy.random.seed(1)
mf.mo_coeff = numpy.random.random(mf.mo_coeff.shape) - .9
mycc = cc.GCCSD(scf.addons.convert_to_ghf(mf))
mycc = cc.GCCSD(scf.addons.convert_to_ghf())
eris = mycc.ao2mo()
nocc = 10
nvir = mol.nao_nr() * 2 - nocc
Expand Down
2 changes: 1 addition & 1 deletion pyscf/cc/gccsd_t_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def make_rdm2(mycc, t1, t2, l1, l2, eris=None):
mol.basis = '631g'
mol.build()
mf0 = mf = scf.RHF(mol).run(conv_tol=1.)
mf = scf.addons.convert_to_ghf(mf)
mf = mf.to_ghf()

from pyscf.cc import ccsd_t_lambda_slow as ccsd_t_lambda
from pyscf.cc import ccsd_t_rdm_slow as ccsd_t_rdm
Expand Down
2 changes: 2 additions & 0 deletions pyscf/cc/rccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ def _make_eris_incore(mycc, mo_coeff=None, ao2mofn=None):
return eris

def _make_eris_outcore(mycc, mo_coeff=None):
from pyscf.scf.hf import RHF
assert isinstance(mycc._scf, RHF)
cput0 = (logger.process_clock(), logger.perf_counter())
log = logger.Logger(mycc.stdout, mycc.verbose)
eris = _ChemistsERIs()
Expand Down
5 changes: 4 additions & 1 deletion pyscf/cc/uccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class UCCSD(ccsd.CCSDBase):
# * A pair of list : First list is the orbital indices to be frozen for alpha
# orbitals, second list is for beta orbitals
def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):
assert isinstance(mf, scf.uhf.UHF)
assert mf.istype('UHF')
ccsd.CCSDBase.__init__(self, mf, frozen, mo_coeff, mo_occ)

get_nocc = get_nocc
Expand Down Expand Up @@ -942,6 +942,7 @@ def _make_eris_incore(mycc, mo_coeff=None, ao2mofn=None):
return eris

def _make_df_eris_outcore(mycc, mo_coeff=None):
assert mycc._scf.istype('UHF')
cput0 = (logger.process_clock(), logger.perf_counter())
log = logger.Logger(mycc.stdout, mycc.verbose)
eris = _ChemistsERIs()
Expand Down Expand Up @@ -1059,6 +1060,8 @@ def _make_df_eris_outcore(mycc, mo_coeff=None):
return eris

def _make_eris_outcore(mycc, mo_coeff=None):
from pyscf.scf.uhf import UHF
assert isinstance(mycc._scf, UHF)
eris = _ChemistsERIs()
eris._common_init_(mycc, mo_coeff)

Expand Down
52 changes: 33 additions & 19 deletions pyscf/ci/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pyscf import lib
from pyscf import scf
from pyscf.ci import cisd
from pyscf.ci import ucisd
Expand All @@ -23,9 +22,10 @@
def CISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.soscf import newton_ah

if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
return UCISD(mf, frozen, mo_coeff, mo_occ)
elif isinstance(mf, scf.rohf.ROHF):
elif mf.istype('ROHF'):
from pyscf import lib
lib.logger.warn(mf, 'RCISD method does not support ROHF method. ROHF object '
'is converted to UHF object and UCISD method is called.')
return UCISD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -34,12 +34,18 @@ def CISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
CISD.__doc__ = cisd.CISD.__doc__

def RCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.df.df_jk import _DFHF
from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.hf.RHF):
mf = scf.addons.convert_to_rhf(mf)
if not mf.istype('RHF'):
mf = mf.to_rhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
from pyscf import lib
lib.logger.warn(mf, f'DF-RCISD for DFHF method {mf} is not available. '
'Normal RCISD method is called.')
return cisd.RCISD(mf, frozen, mo_coeff, mo_occ)
else:
return cisd.RCISD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -48,10 +54,15 @@ def RCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
def UCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.uhf.UHF):
mf = scf.addons.convert_to_uhf(mf)
if not mf.istype('UHF'):
mf = mf.to_uhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
from pyscf import lib
lib.logger.warn(mf, f'DF-UCISD for DFHF method {mf} is not available. '
'Normal UCISD method is called.')
return ucisd.UCISD(mf, frozen, mo_coeff, mo_occ)
else:
return ucisd.UCISD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -61,20 +72,22 @@ def UCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
def GCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf.soscf import newton_ah

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.ghf.GHF):
mf = scf.addons.convert_to_ghf(mf)
if not mf.istype('GHF'):
mf = mf.to_ghf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

if getattr(mf, 'with_df', None):
if isinstance(mf, _DFHF) and mf.with_df:
raise NotImplementedError('DF-GCISD')
else:
return gcisd.GCISD(mf, frozen, mo_coeff, mo_occ)
GCISD.__doc__ = gcisd.GCISD.__doc__


def QCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
raise NotImplementedError
elif isinstance(mf, scf.ghf.GHF):
elif mf.istype('GHF'):
raise NotImplementedError
else:
return RQCISD(mf, frozen, mo_coeff, mo_occ)
Expand All @@ -87,16 +100,17 @@ def RQCISD(mf, frozen=None, mo_coeff=None, mo_occ=None):
from pyscf import lib
from pyscf.soscf import newton_ah

if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
raise RuntimeError('RQCISD cannot be used with UHF method.')
elif isinstance(mf, scf.rohf.ROHF):
elif mf.istype('ROHF'):
lib.logger.warn(mf, 'RQCISD method does not support ROHF method. ROHF object '
'is converted to UHF object and UQCISD method is called.')
mf = scf.addons.convert_to_uhf(mf)
raise NotImplementedError

if isinstance(mf, newton_ah._CIAH_SOSCF) or not isinstance(mf, scf.hf.RHF):
mf = scf.addons.convert_to_rhf(mf)
if not mf.istype('RHF'):
mf = mf.to_rhf()
if isinstance(mf, newton_ah._CIAH_SOSCF):
mf = mf.undo_soscf()

elif numpy.iscomplexobj(mo_coeff) or numpy.iscomplexobj(mf.mo_coeff):
raise NotImplementedError
Expand Down
Loading

0 comments on commit 7860b32

Please sign in to comment.