diff --git a/pyscf/df/grad/rhf.py b/pyscf/df/grad/rhf.py index eade52869b..d7669fda57 100644 --- a/pyscf/df/grad/rhf.py +++ b/pyscf/df/grad/rhf.py @@ -368,26 +368,25 @@ def _decompose_rdm1 (mf_grad, mol, dm): return orbol, orbor def _gen_metric_solver(int2c, decompose_j2c='CD', lindep=LINEAR_DEP_THRESHOLD): - decompose_j2c = decompose_j2c.upper() - j2c_solver = None - if decompose_j2c == 'CD': + if decompose_j2c.upper() == 'CD': try: j2c = scipy.linalg.cho_factor(int2c, lower=True) def j2c_solver(v): return scipy.linalg.cho_solve(j2c, v, overwrite_b=True) + return j2c_solver + except (numpy.linalg.LinAlgError, scipy.linalg.LinAlgError): pass - if j2c_solver is None or decompose_j2c != 'CD': - w, v = scipy.linalg.eigh(int2c) - mask = w > lindep - v1 = v[:,mask] - j2c = lib.dot(v1/w[mask], v1.conj().T) - def j2c_solver(v): - if v.ndim == 2: - return lib.dot(j2c, v) - else: - return j2c.dot(v) + w, v = scipy.linalg.eigh(int2c) + mask = w > lindep + v1 = v[:,mask] + j2c = lib.dot(v1/w[mask], v1.conj().T) + def j2c_solver(v): # noqa: F811 + if v.ndim == 2: + return lib.dot(j2c, v) + else: + return j2c.dot(v) return j2c_solver def _cho_solve_rhojk (mf_grad, mol, auxmol, orbol, orbor,