Skip to content

Commit

Permalink
Add istype function to handle both the mol-SCF and gamma-point-SCF ob…
Browse files Browse the repository at this point in the history
…jects
  • Loading branch information
sunqm committed Oct 26, 2023
1 parent fde229f commit 2fc341a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 56 deletions.
24 changes: 11 additions & 13 deletions pyscf/pbc/scf/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,13 @@ def get_occ(self, mo_energy_kpts=None, mo_coeff_kpts=None):
This is a k-point version of scf.hf.SCF.get_occ
'''
from pyscf.scf import uhf, rohf, ghf
if (self.sigma == 0) or (not self.sigma) or (not self.smearing_method):
mo_occ_kpts = super().get_occ(mo_energy_kpts, mo_coeff_kpts)
return mo_occ_kpts

is_uhf = isinstance(self, uhf.UHF)
is_ghf = isinstance(self, ghf.GHF)
is_rhf = (not is_uhf) and (not is_ghf)
is_rohf = isinstance(self, rohf.ROHF)
is_uhf = self.istype('KUHF')
is_rhf = self.istype('KRHF')
is_rohf = self.istype('KROHF')

sigma = self.sigma
if self.smearing_method.lower() == 'fermi':
Expand Down Expand Up @@ -196,14 +194,13 @@ def get_occ(self, mo_energy_kpts=None, mo_coeff_kpts=None):
return mo_occ_kpts

def get_grad(self, mo_coeff_kpts, mo_occ_kpts, fock=None):
from pyscf.scf import uhf
if (self.sigma == 0) or (not self.sigma) or (not self.smearing_method):
return super().get_grad(mo_coeff_kpts, mo_occ_kpts, fock)

if fock is None:
dm1 = self.make_rdm1(mo_coeff_kpts, mo_occ_kpts)
fock = self.get_hcore() + self.get_veff(self.mol, dm1)
if isinstance(self, uhf.UHF):
if self.istype('KUHF'):
ga = _get_grad_tril(mo_coeff_kpts[0], mo_occ_kpts[0], fock[0])
gb = _get_grad_tril(mo_coeff_kpts[1], mo_occ_kpts[1], fock[1])
return numpy.hstack((ga,gb))
Expand Down Expand Up @@ -288,7 +285,7 @@ def convert_to_uhf(mf, out=None):
if isinstance(mf, (scf.uhf.UHF, scf.kuhf.KUHF)):
return mf.copy()
else:
if isinstance(mf, scf.kghf.KGHF):
if isinstance(mf, (scf.ghf.GHF, scf.kghf.KGHF)):
raise NotImplementedError(
f'No conversion from {mf.__class__} to uhf object')

Expand Down Expand Up @@ -339,7 +336,7 @@ def convert_to_rhf(mf, out=None):
assert (not isinstance(out, scf.khf.KSCF))
out = mol_addons._update_mf_without_soscf(mf, out, False)

elif nelec[0] != nelec[1] and isinstance(mf, scf.rohf.ROHF):
elif nelec[0] != nelec[1] and isinstance(mf, (scf.rohf.ROHF, scf.krohf.KROHF)):
if getattr(mf, '_scf', None):
return mol_addons._update_mf_without_soscf(mf, mf._scf.copy(), False)
else:
Expand All @@ -349,7 +346,7 @@ def convert_to_rhf(mf, out=None):
if isinstance(mf, (scf.hf.RHF, scf.khf.KRHF)):
return mf.copy()
else:
if isinstance(mf, scf.kghf.KGHF):
if isinstance(mf, (scf.ghf.GHF, scf.kghf.KGHF)):
raise NotImplementedError(
f'No conversion from {mf.__class__} to rhf object')

Expand Down Expand Up @@ -397,7 +394,7 @@ def convert_to_ghf(mf, out=None):
else:
assert (not isinstance(out, scf.khf.KSCF))

if isinstance(mf, scf.ghf.GHF):
if isinstance(mf, (scf.ghf.GHF, scf.ghf.KGHF)):
if out is None:
return mf.copy()
else:
Expand All @@ -416,7 +413,7 @@ def update_mo_(mf, mf1):
nkpts = mf.kpts.nkpts_ibz
else:
nkpts = len(mf.kpts)
is_rhf = isinstance(mf, scf.hf.RHF)
is_rhf = mf.istype('KRHF')
for k in range(nkpts):
if is_rhf:
mo_a = mo_b = mf.mo_coeff[k]
Expand Down Expand Up @@ -472,6 +469,7 @@ def convert_to_kscf(mf, out=None):
'''
from pyscf.pbc import scf, dft
if not isinstance(mf, scf.khf.KSCF):
assert isinstance(mf, scf.hf.SCF)
known_cls = {
dft.uks.UKS : dft.kuks.KUKS ,
dft.roks.ROKS : dft.kroks.KROKS,
Expand All @@ -484,7 +482,7 @@ def convert_to_kscf(mf, out=None):
}
mf = mol_addons._object_without_soscf(mf, known_cls, False)
if mf.mo_energy is not None:
if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF'):
mf.mo_occ = mf.mo_occ[:,numpy.newaxis]
mf.mo_coeff = mf.mo_coeff[:,numpy.newaxis]
mf.mo_energy = mf.mo_energy[:,numpy.newaxis]
Expand Down
3 changes: 2 additions & 1 deletion pyscf/scf/_response_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _gen_rhf_response(mf, mo_coeff=None, mo_occ=None,
orbital hessian or CPHF will be generated. If singlet is boolean,
it is used in TDDFT response kernel.
'''
assert (not isinstance(mf, (uhf.UHF, rohf.ROHF)))
assert isinstance(mf, hf.RHF) and not isinstance(mf, (uhf.UHF, rohf.ROHF))

if mo_coeff is None: mo_coeff = mf.mo_coeff
if mo_occ is None: mo_occ = mf.mo_occ
Expand Down Expand Up @@ -148,6 +148,7 @@ def _gen_uhf_response(mf, mo_coeff=None, mo_occ=None,
'''Generate a function to compute the product of UHF response function and
UHF density matrices.
'''
assert isinstance(mf, (uhf.UHF, rohf.ROHF))
if mo_coeff is None: mo_coeff = mf.mo_coeff
if mo_occ is None: mo_occ = mf.mo_occ
mol = mf.mol
Expand Down
70 changes: 33 additions & 37 deletions pyscf/scf/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from pyscf.gto import mole
from pyscf.lib import logger
from pyscf.lib.scipy_helper import pivoted_cholesky
from pyscf.scf import hf
from pyscf import __config__

LINEAR_DEP_THRESHOLD = getattr(__config__, 'scf_addons_remove_linear_dep_threshold', 1e-8)
Expand All @@ -44,6 +43,7 @@ def smearing(mf, sigma=None, method=SMEARING_METHOD, mu0=None, fix_spin=False):
mf.fix_spin = fix_spin
return mf

assert not mf.istype('KSCF')
return lib.set_class(_SmearingSCF(mf, sigma, method, mu0, fix_spin),
(_SmearingSCF, mf.__class__))

Expand Down Expand Up @@ -116,16 +116,14 @@ def undo_smearing(self):
def get_occ(self, mo_energy=None, mo_coeff=None):
'''Label the occupancies for each orbital
'''
from pyscf.scf import uhf, rohf, ghf
from pyscf.pbc.tools import print_mo_energy_occ
if (self.sigma == 0) or (not self.sigma) or (not self.smearing_method):
mo_occ = super().get_occ(mo_energy, mo_coeff)
return mo_occ

is_uhf = isinstance(self, uhf.UHF)
is_ghf = isinstance(self, ghf.GHF)
is_rhf = (not is_uhf) and (not is_ghf)
is_rohf = isinstance(self, rohf.ROHF)
is_uhf = self.istype('UHF')
is_rhf = self.istype('RHF')
is_rohf = self.istype('ROHF')

sigma = self.sigma
if self.smearing_method.lower() == 'fermi':
Expand Down Expand Up @@ -207,14 +205,13 @@ def _get_entropy(self, mo_energy, mo_occ, mu):
return entropy

def get_grad(self, mo_coeff, mo_occ, fock=None):
from pyscf.scf import uhf
if (self.sigma == 0) or (not self.sigma) or (not self.smearing_method):
return super().get_grad(mo_coeff, mo_occ, fock)

if fock is None:
dm1 = self.make_rdm1(mo_coeff, mo_occ)
fock = self.get_hcore() + self.get_veff(self.mol, dm1)
if isinstance(self, uhf.UHF):
if self.istype('UHF'):
ga = _get_grad_tril(mo_coeff[0], mo_occ[0], fock[0])
gb = _get_grad_tril(mo_coeff[1], mo_occ[1], fock[1])
return numpy.hstack((ga,gb))
Expand All @@ -241,7 +238,7 @@ def frac_occ_(mf, tol=1e-3):
>>> mf = scf.addons.frac_occ(mf)
>>> mf.run()
'''
from pyscf.scf import hf, uhf, rohf
from pyscf.scf import uhf, rohf
old_get_occ = mf.get_occ
mol = mf.mol

Expand All @@ -263,7 +260,7 @@ def guess_occ(mo_energy, nocc):
frac_occ_lst = numpy.zeros_like(mo_energy, dtype=bool)
return mo_occ, numpy.where(frac_occ_lst)[0], homo, lumo

if isinstance(mf, uhf.UHF):
if mf.istype('UHF'):
def get_occ(mo_energy, mo_coeff=None):
nocca, noccb = mol.nelec
mo_occa, frac_lsta, homoa, lumoa = guess_occ(mo_energy[0], nocca)
Expand All @@ -289,7 +286,7 @@ def get_occ(mo_energy, mo_coeff=None):
mo_occ = old_get_occ(mo_energy, mo_coeff)
return mo_occ

elif isinstance(mf, rohf.ROHF):
elif mf.istype('ROHF'):
def get_occ(mo_energy, mo_coeff=None):
nocca, noccb = mol.nelec
mo_occa, frac_lsta, homoa, lumoa = guess_occ(mo_energy, nocca)
Expand Down Expand Up @@ -336,7 +333,7 @@ def get_grad(mo_coeff, mo_occ, fock):
g[uniq_var_b] += fockb[uniq_var_b]
return g[uniq_var_a | uniq_var_b]

elif isinstance(mf, hf.RHF):
elif mf.istype('RHF'):
def get_occ(mo_energy, mo_coeff=None):
nocc = (mol.nelectron+1) // 2 # n_docc + n_socc
mo_occ, frac_lst, homo, lumo = guess_occ(mo_energy, nocc)
Expand Down Expand Up @@ -379,7 +376,7 @@ def dynamic_occ_(mf, tol=1e-3):
'''
Dynamically adjust the occupancy to avoid degeneracy between HOMO and LUMO
'''
assert (isinstance(mf, hf.RHF))
assert mf.istype('RHF')
old_get_occ = mf.get_occ
def get_occ(mo_energy, mo_coeff=None):
mol = mf.mol
Expand Down Expand Up @@ -431,7 +428,7 @@ def float_occ_(mf):
Determine occupation of alpha and beta electrons based on energy spectrum
'''
from pyscf.scf import uhf
assert (isinstance(mf, uhf.UHF))
assert mf.istype('UHF')
def get_occ(mo_energy, mo_coeff=None):
mol = mf.mol
ee = numpy.sort(numpy.hstack(mo_energy))
Expand Down Expand Up @@ -477,22 +474,22 @@ def mom_occ_(mf, occorb, setocc):
iteration.'''
from pyscf.scf import uhf, rohf
log = logger.Logger(mf.stdout, mf.verbose)
if isinstance(mf, uhf.UHF):
if mf.istype('UHF'):
coef_occ_a = occorb[0][:, setocc[0]>0]
coef_occ_b = occorb[1][:, setocc[1]>0]
elif isinstance(mf, rohf.ROHF):
elif mf.istype('ROHF'):
if mf.mol.spin != (numpy.sum(setocc[0]) - numpy.sum(setocc[1])):
raise ValueError('Wrong occupation setting for restricted open-shell calculation.')
coef_occ_a = occorb[:, setocc[0]>0]
coef_occ_b = occorb[:, setocc[1]>0]
else: # GHF, and DHF
assert setocc.ndim == 1

if isinstance(mf, (uhf.UHF, rohf.ROHF)):
if mf.istype('UHF') or mf.istype('ROHF'):
def get_occ(mo_energy=None, mo_coeff=None):
if mo_energy is None: mo_energy = mf.mo_energy
if mo_coeff is None: mo_coeff = mf.mo_coeff
if isinstance(mf, rohf.ROHF):
if mf.istype('ROHF'):
mo_coeff = numpy.array([mo_coeff, mo_coeff])
mo_occ = numpy.zeros_like(setocc)
nocc_a = int(numpy.sum(setocc[0]))
Expand Down Expand Up @@ -521,7 +518,7 @@ def get_occ(mo_energy=None, mo_coeff=None):
nocc_b, int(numpy.sum(mo_occ[1])))

#output 1-dimension occupation number for restricted open-shell
if isinstance(mf, rohf.ROHF): mo_occ = mo_occ[0, :] + mo_occ[1, :]
if mf.istype('ROHF'): mo_occ = mo_occ[0, :] + mo_occ[1, :]
return mo_occ
else:
def get_occ(mo_energy=None, mo_coeff=None):
Expand Down Expand Up @@ -780,18 +777,18 @@ def convert_to_uhf(mf, out=None, remove_df=False):
'''
from pyscf import scf
from pyscf import dft
assert (isinstance(mf, hf.SCF))
assert (isinstance(mf, scf.hf.SCF))

logger.debug(mf, 'Converting %s to UHF', mf.__class__)

if isinstance(mf, scf.ghf.GHF):
if mf.istype('GHF'):
raise NotImplementedError

elif out is not None:
assert (isinstance(out, scf.uhf.UHF))
assert out.istype('UHF')
out = _update_mf_without_soscf(mf, out, remove_df)

elif isinstance(mf, scf.uhf.UHF):
elif mf.istype('UHF'):
# Remove with_df for SOSCF method because the post-HF code checks the
# attribute .with_df to identify whether an SCF object is DF-SCF method.
# with_df in SOSCF is used in orbital hessian approximation only. For the
Expand Down Expand Up @@ -884,7 +881,7 @@ def convert_to_rhf(mf, out=None, remove_df=False):
'''
from pyscf import scf
from pyscf import dft
assert (isinstance(mf, hf.SCF))
assert (isinstance(mf, scf.hf.SCF))

logger.debug(mf, 'Converting %s to RHF', mf.__class__)

Expand All @@ -893,15 +890,15 @@ def convert_to_rhf(mf, out=None, remove_df=False):
else:
nelec = mf.nelec

if isinstance(mf, scf.ghf.GHF):
if mf.istype('GHF'):
raise NotImplementedError

elif out is not None:
assert (isinstance(out, scf.hf.RHF))
assert out.istype('RHF')
out = _update_mf_without_soscf(mf, out, remove_df)

elif (isinstance(mf, scf.hf.RHF) or
(nelec[0] != nelec[1] and isinstance(mf, scf.rohf.ROHF))):
elif (mf.istype('RHF') or
(nelec[0] != nelec[1] and mf.istype('ROHF'))):
if getattr(mf, '_scf', None):
return _update_mf_without_soscf(mf, mf._scf.copy(), remove_df)
else:
Expand Down Expand Up @@ -951,15 +948,15 @@ def convert_to_ghf(mf, out=None, remove_df=False):
'''
from pyscf import scf
from pyscf import dft
assert (isinstance(mf, hf.SCF))
assert (isinstance(mf, scf.hf.SCF))

logger.debug(mf, 'Converting %s to GHF', mf.__class__)

if out is not None:
assert (isinstance(out, scf.ghf.GHF))
assert out.istype('GHF')
out = _update_mf_without_soscf(mf, out, remove_df)

elif isinstance(mf, scf.ghf.GHF):
elif mf.istype('GHF'):
if getattr(mf, '_scf', None):
return _update_mf_without_soscf(mf, mf._scf.copy(), remove_df)
else:
Expand All @@ -985,11 +982,10 @@ def convert_to_ghf(mf, out=None, remove_df=False):
return _update_mo_to_ghf_(mf, out)

def _update_mo_to_uhf_(mf, mf1):
from pyscf import scf
if mf.mo_energy is None:
return mf1

if isinstance(mf, scf.uhf.UHF):
if mf.istype('UHF') or mf.istype('KUHF'):
mf1.mo_occ = mf.mo_occ
mf1.mo_coeff = mf.mo_coeff
mf1.mo_energy = mf.mo_energy
Expand All @@ -1011,11 +1007,10 @@ def _update_mo_to_uhf_(mf, mf1):
return mf1

def _update_mo_to_rhf_(mf, mf1):
from pyscf import scf
if mf.mo_energy is None:
return mf1

if isinstance(mf, scf.hf.RHF): # RHF/ROHF/KRHF/KROHF
if mf.istype('RHF') or mf.istype('KRHF'): # RHF/ROHF/KRHF/KROHF
mf1.mo_occ = mf.mo_occ
mf1.mo_coeff = mf.mo_coeff
mf1.mo_energy = mf.mo_energy
Expand All @@ -1034,11 +1029,12 @@ def _update_mo_to_rhf_(mf, mf1):
return mf1

def _update_mo_to_ghf_(mf, mf1):
from pyscf import scf
if mf.mo_energy is None:
return mf1

if isinstance(mf, scf.hf.RHF): # RHF
if mf.istype('KSCF'):
raise NotImplementedError('KSCF')
elif mf.istype('RHF'):
nao, nmo = mf.mo_coeff.shape
orbspin = get_ghf_orbspin(mf.mo_energy, mf.mo_occ, True)

Expand Down
14 changes: 14 additions & 0 deletions pyscf/scf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,20 @@ def to_gpu(self):
'''
raise NotImplementedError

def istype(self, type_code):
'''
Checks if the object is an instance of the class specified by the type_code.
type_code can be a class or a str. If the type_code is a class, it is
equivalent to the Python built-in function `isinstance`. If the type_code
is a str, it checks the type_code against the names of the object and all
its parent classes.
'''
if isinstance(type_code, type):
# type_code is a class
return isinstance(self, type_code)

return any(type_code == t.__name__ for t in self.__class__.__mro__)


class KohnShamDFT:
'''A mock DFT base class
Expand Down
Loading

0 comments on commit 2fc341a

Please sign in to comment.