From 412e80ca29097703866a021de1e25f5ec60b887e Mon Sep 17 00:00:00 2001 From: Alexander Hampel Date: Mon, 20 May 2024 14:13:10 -0400 Subject: [PATCH] [feat] allow PCB to read from TRIQS TB object --- .../postprocessing/plot_correlated_bands.py | 38 ++++++++++++------- test/python/test_plot_correlated_bands.py | 14 ++++++- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/python/solid_dmft/postprocessing/plot_correlated_bands.py b/python/solid_dmft/postprocessing/plot_correlated_bands.py index fbb343d9..f688583f 100644 --- a/python/solid_dmft/postprocessing/plot_correlated_bands.py +++ b/python/solid_dmft/postprocessing/plot_correlated_bands.py @@ -668,12 +668,12 @@ def plot_kslice(fig, ax, alatt_k_w, tb_data, freq_dict, n_orb, tb_dict, tb=True, return ax -def get_dmft_bands(n_orb, w90_path, w90_seed, mu_tb, add_spin=False, add_lambda=None, add_local=None, +def get_dmft_bands(n_orb, mu_tb, w90_path=None, w90_seed=None, TB_obj=None, add_spin=False, add_lambda=None, add_local=None, with_sigma=None, fermi_slice=False, qp_bands=False, orbital_order_to=None, add_mu_tb=False, band_basis=False, proj_on_orb=None, trace=True, eta=0.0, mu_shift=0.0, proj_nuk=None, **specs): ''' - Extract tight-binding from given w90 seed_hr.dat and seed.wout files, and then extract from + Extract tight-binding from given w90 seed_hr.dat and seed.wout files or alternatively given TB_obj, and then extract from given solid_dmft calculation the self-energy and construct the spectral function A(k,w) on given k-path. @@ -681,10 +681,14 @@ def get_dmft_bands(n_orb, w90_path, w90_seed, mu_tb, add_spin=False, add_lambda= ---------- n_orb : int Number of Wannier orbitals in seed_hr.dat + mu_tb : float + Chemical potential of tight-binding calculation w90_path : string Path to w90 files w90_seed : string Seed of wannier90 calculation, i.e. seed_hr.dat and seed.wout + TB_obj : TB object + Tight-binding object from TB_from_wannier90 add_spin : bool, default=False Extend w90 Hamiltonian by spin indices add_lambda : float, default=None @@ -764,19 +768,27 @@ def get_dmft_bands(n_orb, w90_path, w90_seed, mu_tb, add_spin=False, add_lambda= if isinstance(proj_nuk, np.ndarray) and not band_basis: band_basis = True - # set up Wannier Hamiltonian - n_orb = 2 * n_orb if add_spin else n_orb - change_of_basis = change_basis(n_orb, orbital_order_to, orbital_order_w90) - H_add_loc = np.zeros((n_orb, n_orb), dtype=complex) - if not isinstance(add_local, type(None)): - assert np.shape(add_local) == (n_orb, n_orb), 'add_local must have dimension (n_orb, n_orb), but has '\ - f'dimension {np.shape(add_local)}' - H_add_loc += add_local - if add_spin and add_lambda: - H_add_loc += lambda_matrix_w90_t2g(add_lambda) + if TB_obj is None: + assert w90_path is not None and w90_seed is not None, 'Please provide either a TB object or a path to the wannier90 files' + # set up Wannier Hamiltonian + n_orb = 2 * n_orb if add_spin else n_orb + change_of_basis = change_basis(n_orb, orbital_order_to, orbital_order_w90) + H_add_loc = np.zeros((n_orb, n_orb), dtype=complex) + if not isinstance(add_local, type(None)): + assert np.shape(add_local) == (n_orb, n_orb), 'add_local must have dimension (n_orb, n_orb), but has '\ + f'dimension {np.shape(add_local)}' + H_add_loc += add_local + if add_spin and add_lambda: + H_add_loc += lambda_matrix_w90_t2g(add_lambda) + + tb = TB_from_wannier90(path=w90_path, seed=w90_seed, extend_to_spin=add_spin, add_local=H_add_loc) + else: + assert not add_spin, 'add_spin is only valid when reading from wannier90 files' + change_of_basis = change_basis(n_orb, orbital_order_to, orbital_order_w90) + tb = TB_obj + eta = eta * 1j - tb = TB_from_wannier90(path=w90_path, seed=w90_seed, extend_to_spin=add_spin, add_local=H_add_loc) # print local H(R) h_of_r = np.einsum('ij, jk -> ik', np.linalg.inv(change_of_basis), np.einsum('ij, jk -> ik', tb.hoppings[(0, 0, 0)], change_of_basis)) if n_orb <= 12: diff --git a/test/python/test_plot_correlated_bands.py b/test/python/test_plot_correlated_bands.py index a9e1d29b..617e1ffa 100644 --- a/test/python/test_plot_correlated_bands.py +++ b/test/python/test_plot_correlated_bands.py @@ -121,7 +121,7 @@ def test_get_kslice_nokz(self): assert np.allclose(tb_data['e_mat'], emat_ref) assert np.allclose(alatt_k_w, Akw_ref) - def test_get_dmft_bands_reg_mesh(self): + def test_get_dmft_bands_reg_mesh_read_TB_obj(self): tb_bands = {'kmesh': 'regular', 'n_k': 7} tb_data, alatt_k_w, freq_dict = pcb.get_dmft_bands(with_sigma='calc', add_mu_tb=True, @@ -135,5 +135,17 @@ def test_get_dmft_bands_reg_mesh(self): assert np.allclose(tb_data['e_mat'], emat_ref) assert np.allclose(alatt_k_w, Akw_ref) + # read now from TB_obj + w90_dict = {'TB_obj': tb_data['tb'], 'mu_tb': 12.3958, 'n_orb': 3, + 'orbital_order_w90': ['dxz', 'dyz', 'dxy']} + + tb_data_obj, alatt_k_w_obj, freq_dict_obj = pcb.get_dmft_bands(with_sigma='calc', add_mu_tb=True, + orbital_order_to=self.orbital_order_to, + **w90_dict, **tb_bands, **self.sigma_dict) + + assert np.allclose(tb_data_obj['e_mat'], emat_ref) + assert np.allclose(alatt_k_w_obj, Akw_ref) + + if __name__ == '__main__': unittest.main()