From 802868eda12574133846fca1b7314a1905db3b50 Mon Sep 17 00:00:00 2001 From: phibeck Date: Fri, 27 Jan 2023 11:22:41 -0500 Subject: [PATCH] fix: improve rootfinder in PCB --- .../postprocessing/plot_correlated_bands.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/python/solid_dmft/postprocessing/plot_correlated_bands.py b/python/solid_dmft/postprocessing/plot_correlated_bands.py index cff2b638..bb1fe6b2 100644 --- a/python/solid_dmft/postprocessing/plot_correlated_bands.py +++ b/python/solid_dmft/postprocessing/plot_correlated_bands.py @@ -38,6 +38,7 @@ from matplotlib import cm from scipy.optimize import brentq from scipy.interpolate import interp1d +from scipy.signal import argrelextrema import numpy as np import itertools import skimage.measure @@ -367,13 +368,24 @@ def invert_and_trace(w, eta, mu, e_mat, sigma, proj=None): else: assert n_kx == n_ky, 'Not implemented for N_kx != N_ky' + + def search_for_extrema(data): + # return None for no extrema, [] if ends of interval are the only extrema, + # list of indices if local extrema are present + answer = np.all(data > 0) or np.all(data < 0) + if answer: + return + else: + roots = [] + roots.append(list(argrelextrema(data, np.greater)[0])) + roots.append(list(argrelextrema(data, np.less)[0])) + roots = sorted([item for sublist in roots for item in sublist]) + return roots + alatt_k_w = np.zeros((n_kx, n_ky, n_orb)) + # go through grid horizontally, then vertically for it in range(2): kslice = np.zeros((n_kx, n_ky, n_orb)) - if it == 0: - def kslice_interp(ik, orb): return interp1d(range(n_kx), kslice[:, ik, orb]) - else: - def kslice_interp(ik, orb): return interp1d(range(n_kx), kslice[ik, :, orb]) for ik1 in range(n_kx): e_temp = e_mat[:, :, :, ik1] if it == 0 else e_mat[:, :, ik1, :] @@ -383,12 +395,21 @@ def kslice_interp(ik, orb): return interp1d(range(n_kx), kslice[ik, :, orb]) kslice[k1, k2] = e_val for orb in range(n_orb): - try: - x0 = brentq(kslice_interp(ik1, orb), 0, n_kx - 1) - k1, k2 = [int(np.floor(x0)), ik1] if it == 0 else [ik1, int(np.floor(x0))] - alatt_k_w[k1, k2, orb] += 1 - except ValueError: - pass + temp_kslice = kslice[:,ik1,orb] if it == 0 else kslice[ik1,:,orb] + roots = search_for_extrema(temp_kslice) + # iterate through sections between extrema + if roots is not None: + idx_1 = 0 + for root_ct in range(len(roots) + 1): + idx_2 = roots[root_ct] if root_ct < len(roots) else n_kx + root_section = temp_kslice[idx_1:idx_2+1] + try: + x0 = brentq(interp1d(np.linspace(idx_1, idx_2, len(root_section)), root_section), idx_1, idx_2) + k1, k2 = [int(np.floor(x0)), ik1] if it == 0 else [ik1, int(np.floor(x0))] + alatt_k_w[k1, k2, orb] += 1 + except(ValueError): + pass + idx_1 = idx_2 alatt_k_w[np.where(alatt_k_w > 1)] = 1