diff --git a/h5_helpers/selfcal_quality.py b/h5_helpers/selfcal_quality.py index 4b951ef3..f84a1e3e 100644 --- a/h5_helpers/selfcal_quality.py +++ b/h5_helpers/selfcal_quality.py @@ -39,8 +39,11 @@ def __init__(self, folder: str = None, remote_only: bool = False, international_ # selfcal folder self.folder = folder + # merged selfcal h5parms - self.h5s = glob(f"{self.folder}/merged_selfcalcyle*.h5") + self.h5s = [h5 for h5 in glob(f"{self.folder}/merged_selfcalcyle*.h5") if 'linearfulljones' not in h5] + if len(self.h5s)==0: + self.h5s = glob(f"{self.folder}/merged_selfcalcyle*.h5") assert len(self.h5s) != 0, "No h5 files found" # select all sources @@ -50,13 +53,14 @@ def __init__(self, folder: str = None, remote_only: bool = False, international_ # select all fits images fitsfiles = sorted(glob(self.folder + "/*MFS-I-image.fits")) - if len(fitsfiles) == 0: + if len(fitsfiles) == 0 or '000' not in fitsfiles[0]: fitsfiles = sorted(glob(self.folder + "/*MFS-image.fits")) self.fitsfiles = [f for f in fitsfiles if 'arcsectaper' not in f] + assert len(self.fitsfiles) != 0, "No fits files found" # select all fits model images modelfiles = sorted(glob(self.folder + "/*MFS-I-model.fits")) - if len(modelfiles) == 0: + if len(modelfiles) == 0 or '000' not in modelfiles[0]: modelfiles = sorted(glob(self.folder + "/*MFS-model.fits")) self.modelfiles = [f for f in modelfiles if 'arcsectaper' not in f] @@ -65,7 +69,7 @@ def __init__(self, folder: str = None, remote_only: bool = False, international_ self.international_only = international_only self.dutch_only = dutch_only - self.textfile = open('selfcal_performance.csv', 'w') + self.textfile = open(f'selfcal_performance_{self.sourcename}.csv', 'w') self.writer = csv.writer(self.textfile) self.writer.writerow(['solutions', 'dirty'] + [str(i) for i in range(len(self.fitsfiles))]) @@ -258,14 +262,12 @@ def get_solution_scores(self, h5_1: str = None, h5_2: str = None): # take std from ratio of previous and current selfcal cycle if h5_2 is not None: - if len(pols1) != len(pols2): if min(len(pols1), len(pols2)) == 1: vals1 = np.take(vals1, [0], axis=axes.index('pol')) vals2 = np.take(vals2, [0], axis=axes.index('pol')) weights1 = np.take(weights1, [0], axis=axes.index('pol')) weights2 = np.take(weights2, [0], axis=axes.index('pol')) - elif min(len(pols1), len(pols2)) == 2: vals1 = np.take(vals1, [0, -1], axis=axes.index('pol')) vals2 = np.take(vals2, [0, -1], axis=axes.index('pol')) @@ -296,7 +298,6 @@ def solution_stability(self): # loop over sources to get scores for k, source in enumerate(self.sources): - print(source) sub_h5s = sorted([h5 for h5 in self.h5s if source in h5]) phase_scores = [] amp_scores = [] @@ -475,7 +476,7 @@ def main(): bestcycle_solutions, accept_solutions = sq.solution_stability() bestcycle_image, accept_image = sq.image_stability() sq.textfile.close() - df = pd.read_csv('selfcal_performance.csv').set_index('solutions').T + df = pd.read_csv(f'selfcal_performance_{sq.sourcename}.csv').set_index('solutions').T print(df) df.to_csv(f'selfcal_performance_{sq.sourcename}.csv', index=False)