diff --git a/phasediff_scores/find_solint.py b/phasediff_scores/find_solint.py index bcaeff88..e7ff966c 100644 --- a/phasediff_scores/find_solint.py +++ b/phasediff_scores/find_solint.py @@ -1,5 +1,10 @@ -from source_selection.phasediff_output import GetSolint +import sys +from pathlib import Path + +# Add the parent directory to sys.path +sys.path.append(str(Path(__file__).resolve().parent.parent)) +from source_selection.phasediff_output import GetSolint if __name__ == "__main__": diff --git a/source_detection/crossmatch.py b/source_detection/crossmatch.py index e9714ca5..1ed0e07e 100644 --- a/source_detection/crossmatch.py +++ b/source_detection/crossmatch.py @@ -1,4 +1,4 @@ -#RUN python source_detection/crossmatch.py --cat1 /home/jurjen/Documents/ELAIS/catalogues/finalcat03/*.fits --cat2 /home/jurjen/Documents/ELAIS/catalogues/pybdsf_sources_6asec.fits +"""Code to crossmatch catalogues""" from astropy.coordinates import SkyCoord from astropy import units as u @@ -9,6 +9,9 @@ import matplotlib.pyplot as plt from matplotlib.colors import LogNorm import warnings +import scienceplots + +plt.style.use('science') # Disable all warnings temporarily warnings.filterwarnings("ignore") @@ -17,6 +20,7 @@ def find_matches(cat1, cat2, separation_asec): """ Find crossmatches with two catalogues """ + catalog1 = Table.read(cat1, format='fits') catalog2 = Table.read(cat2, format='fits') @@ -39,18 +43,17 @@ def find_matches(cat1, cat2, separation_asec): return matched_sources_catalog1, matched_sources_catalog2 -def remove_non_matches(cat1, cat2, separation_asec, rms_thresh=5.5, flux_thresh=4): +def separation_match(cat1, cat2, separation_asec): """ - Find the non-crossmatches between two catalogues and remove those from catalogue 1 + Find non-crossmatches between two catalogues and remove those from catalogue 1 based on distance in arcsec :param cat1: catalogue 1 :param cat2: catalogue 2 - :param separation_asec: max separation between catalogue matches - :param rms_thresh: consider only non-detections below this RMS threshold - :param flux_thresh: flux ratio threshold between catalogue 1 and catalogue 2 + :param separation_asec: max separation between catalogue matches in arcsec :return: corrected catalogue 1, removed sources from catalogue 1 """ + catalog1 = Table.read(cat1, format='fits') catalog2 = Table.read(cat2, format='fits') @@ -61,27 +64,16 @@ def remove_non_matches(cat1, cat2, separation_asec, rms_thresh=5.5, flux_thresh= # Perform the crossmatch using Astropy's match_coordinates_sky function idx_catalog2, separation, _ = match_coordinates_sky(coords1, coords2) catalog1['separation_cat2'] = separation - # catalog1['Total_flux_ratio'] = catalog1['Total_flux']/catalog2[idx_catalog2]['Total_flux'] # Define a maximum separation threshold (adjustplots as needed) max_sep_threshold_large = separation_asec * u.arcsec - # max_sep_threshold_small = separation_asec/2 * u.arcsec - # factor = max_sep_threshold_large.value/max_sep_threshold_small.value # Sources below RMS threshold and above separation threshold non_matches_large = catalog1[(catalog1['separation_cat2'] > max_sep_threshold_large)] - # # Sources below RMS threshold/2 and above smaller separation threshold - # non_matches_small = catalog1[(catalog1['separation_cat2'] > max_sep_threshold_small) - # & (catalog1['Peak_flux'] < rms_thresh/factor * catalog1['Isl_rms'])] - # # Sources above flux ratio threshold and below threshold/2 - # non_matches_flux = catalog1[((catalog1['Total_flux_ratio']>flux_thresh) | (catalog1['Total_flux_ratio']<1/flux_thresh)) - # & (catalog1['Peak_flux'] < rms_thresh/factor * catalog1['Isl_rms'])] non_match_mask_large = np.sum([non_matches_large['Source_id'] == i for i in catalog1['Source_id']], axis=1).astype(bool) - # non_match_mask_small = np.sum([non_matches_small['Source_id'] == i for i in catalog1['Source_id']], axis=1).astype(bool) - # non_match_mask_flux = np.sum([non_matches_flux['Source_id'] == i for i in catalog1['Source_id']], axis=1).astype(bool) - non_match_mask = non_match_mask_large #| non_match_mask_small | non_match_mask_flux + non_match_mask = non_match_mask_large catalog1_corrected = catalog1[~non_match_mask] catalog1_removed = catalog1[non_match_mask] @@ -109,6 +101,7 @@ def crossmatch_itself(catalog, min_sep=0.15): def make_plots(cat, res=0.3): + """Make plots""" # make peak flux plot plt.hist(np.log10(cat['Peak_flux'] * 1000), bins=30) @@ -119,7 +112,7 @@ def make_plots(cat, res=0.3): # make total flux plot plt.hist(np.log10(cat['Total_flux'] * 1000), bins=30) plt.xlabel('Total flux (mJy)') - plt.savefig('total_flux.png') + plt.savefig('total_flux.png', dpi=150) plt.close() ############### 2D HISTOGRAM ############### @@ -144,10 +137,33 @@ def make_plots(cat, res=0.3): plt.plot([mediandRA]*len(axs), axs, color='black', linestyle='--') plt.xlim(-6, 6) plt.ylim(-6, 6) - plt.savefig('dRA_dDEC.png') + plt.savefig('dRA_dDEC.png', dpi=150) plt.close() + ############# Flux ratio 6" ############## + subcat = cat[(cat['S_Code'] == 'S') & (cat['Total_flux'] * 1000 > 0.1)] + subcat = subcat[(subcat['dDEC_0.3']*3600 < 0.15) & (subcat['dRA_0.3']*3600 < 0.15)] + R = subcat['Total_flux_6'] / subcat['Total_flux'] + plt.scatter(subcat['Total_flux_6'], R, color='darkred') + plt.xscale('log') + plt.yscale('log') + plt.xlabel("Total flux") + plt.ylabel("Flux ratio") + plt.title(f'Median ratio: {round(np.median(R[np.isfinite(R)]), 2)}') + plt.savefig('lotssdeep_ratio.png', dpi=150) + + ############# Peak flux over Total flux ############## + subcat = cat[(cat['S_Code'] == 'S') & (cat['Total_flux'] * 1000 > 1)] + R = subcat['Peak_flux'] / subcat['Total_flux'] + plt.scatter(subcat['Total_flux'], R) + plt.xscale('log') + plt.xlabel("Total flux") + plt.ylabel("Peakflux/Totalflux") + plt.title(f'Median ratio: {round(np.median(R[np.isfinite(R)]), 2)}') + plt.savefig('peak_total.png', dpi=150) + + def merge_with_table(catalog1, catalog2, sep=6, res=0.3): """Merge with other table""" @@ -188,9 +204,7 @@ def merge_with_table(catalog1, catalog2, sep=6, res=0.3): def parse_args(): - """ - Parse input arguments - """ + """Parse input arguments""" parser = argparse.ArgumentParser(description='Catalogue matching') parser.add_argument('--cat1', nargs='+', help='Catalogue 1 (can be multiple)', default=None) @@ -200,7 +214,7 @@ def parse_args(): "6asec/pybdsf_sources_6asec.fits") parser.add_argument('--separation_asec', type=float, default=6, help= 'minimal separation between catalogue 1 and catalogue 2') - parser.add_argument('--cat_prefix', type=str) + parser.add_argument('--source_id_prefix', type=str) parser.add_argument('--out_table', type=str, default='final_merged.fits') parser.add_argument('--resolution', type=float, default=0.3) parser.add_argument('--only_plot', action='store_true', help='make only plot') @@ -211,8 +225,8 @@ def parse_args(): def main(): """Main""" - outcols = ['Cat_id', 'Isl_id', 'RA','E_RA','DEC','E_DEC','Total_flux','E_Total_flux','Peak_flux','E_Peak_flux', - 'Maj','E_Maj','Min','E_Min','PA','E_PA', 'S_Code', 'Isl_rms'] + outcols = ['Cat_id', 'Isl_id', 'RA', 'E_RA', 'DEC','E_DEC', 'Total_flux', 'E_Total_flux', 'Peak_flux', 'E_Peak_flux', + 'Maj', 'E_Maj', 'Min', 'E_Min', 'PA', 'E_PA', 'S_Code', 'Isl_rms'] args = parse_args() @@ -222,9 +236,9 @@ def main(): else: for n, cat in enumerate(args.cat1): print(cat) - catalog1_new, _ = remove_non_matches(cat, args.cat2, args.separation_asec) - if args.cat_prefix is not None: - catalog1_new['Cat_id'] = [f'{args.cat_prefix}_{id}' for id in list(catalog1_new['Source_id'])] + catalog1_new, _ = separation_match(cat, args.cat2, args.separation_asec) + if args.source_id_prefix is not None: + catalog1_new['Cat_id'] = [f'{args.source_id_prefix}_{id}' for id in list(catalog1_new['Source_id'])] else: catalog1_new['Cat_id'] = [f'{cat.split("/")[-1].split("_")[1]}_{id}' for id in list(catalog1_new['Source_id'])] if n==0: @@ -250,3 +264,6 @@ def main(): if __name__ == '__main__': main() + +# python source_detection/crossmatch.py --cat1 /home/jurjen/Documents/ELAIS/catalogues/finalcat03/*.fits --cat2 /home/jurjen/Documents/ELAIS/catalogues/pybdsf_sources_6asec.fits --out_table final_merged_03.fits +# python source_detection/crossmatch.py --cat1 /home/jurjen/Documents/ELAIS/catalogues/finalcat06/*.fits --cat2 /home/jurjen/Documents/ELAIS/catalogues/pybdsf_sources_6asec.fits --out_table final_merged_06.fits diff --git a/source_selection/selfcal_selection.py b/source_selection/selfcal_selection.py index e538e4a6..70fe3c70 100644 --- a/source_selection/selfcal_selection.py +++ b/source_selection/selfcal_selection.py @@ -24,26 +24,31 @@ import pandas as pd from cv2 import bilateralFilter from typing import Union +import warnings + +# Ignore all warnings +warnings.filterwarnings('ignore') + +plt.style.use('ggplot') class SelfcalQuality: - def __init__(self, folder: str, station: str): + def __init__(self, h5s: list, fitsim: list, station: str): """ Determine quality of selfcal from facetselfcal.py + :param fitsim: fits images + :param h5s: h5parm solutions :param folder: path to directory where selfcal ran :param station: which stations to consider [dutch, remote, international, debug] """ - # selfcal folder - self.folder = folder - # merged selfcal h5parms - self.h5s = [h5 for h5 in glob(f"{self.folder}/merged_addCS_selfcalcyle*.h5") if 'linearfulljones' not in h5] - if len(self.h5s) == 0: - self.h5s = glob(f"{self.folder}/merged_addCS_selfcalcyle*.h5") - if len(self.h5s) == 0: - raise FileNotFoundError("WARNING: No h5 files found") - # assert len(self.h5s) != 0, "No h5 files found" + self.h5s = h5s + assert len(self.h5s) != 0, "WARNING: No h5 files given" + + # selfcal images + self.fitsfiles = fitsim + assert len(self.fitsfiles) != 0, "No fits files found" # select all sources sources = [] @@ -51,21 +56,18 @@ def __init__(self, folder: str, station: str): matches = re.findall(r'selfcalcyle\d+_(.*?)\.', h5.split('/')[-1]) assert len(matches) == 1 sources.append(matches[0]) - self.sources = set(sources) - assert len(self.sources) > 0, "No sources found" - # select all fits images - fitsfiles = sorted(glob(self.folder + "/*MFS-I-image.fits")) - if len(fitsfiles) == 0 or '000' not in fitsfiles[0]: - fitsfiles = sorted(glob(self.folder + "/*MFS-image.fits")) - self.fitsfiles = [f for f in fitsfiles if 'arcsectaper' not in f] - assert len(self.fitsfiles) != 0, "No fits files found" - # for phase/amp evolution self.station = station + self.station_codes = ( + ('CS',) if self.station == 'dutch' else + ('RS',) if self.station == 'remote' else + ('IE', 'SE', 'PL', 'UK', 'LV') # i.e.: if self.station == 'international' + ) + # output csv self.textfile = open(f'selfcal_performance.csv', 'w') self.writer = csv.writer(self.textfile) @@ -134,8 +136,6 @@ def make_figure(vals1=None, vals2=None, label1=None, label2=None, plotname=None) # ax1.set_ylim(0, np.pi/2) ax1.set_xlim(1, 11) ax1.grid(False) - ax1.grid('off') - ax1.grid(None) if vals2 is not None: @@ -151,8 +151,6 @@ def make_figure(vals1=None, vals2=None, label1=None, label2=None, plotname=None) # ax2.set_ylim(0, 2) ax2.set_xlim(1, 11) ax2.grid(False) - ax2.grid('off') - ax2.grid(None) fig.tight_layout() @@ -189,6 +187,30 @@ def image_entropy(self, fitsfile: str = None): print(f"Entropy: {val}") return val + def filter_stations(self, station_names): + """Generate indices of filtered stations""" + + if self.station == 'debug': + return list(range(len(station_names))) + + output_stations = [ + i for i, station_name in enumerate(station_names) + if any(station_code in station_name for station_code in self.station_codes) + ] + + return output_stations + + def print_station_names(self, h5): + """Print station names""" + H = tables.open_file(h5) + v_make_utf8 = np.vectorize(self.make_utf8) + stations = v_make_utf8(H.root.sol000.antenna[:]['name']) + + stations_used = ', '.join([station_name for station_name in stations + if any(station_code in station_name for station_code in self.station_codes)]) + print(f'Used the following stations: {stations_used}') + return self + def get_solution_scores(self, h5_1: str, h5_2: str = None): """ Get solution scores @@ -202,31 +224,34 @@ def get_solution_scores(self, h5_1: str, h5_2: str = None): def extract_data(tables_path): with tables.open_file(tables_path) as f: - return ( - [self.make_utf8(station) for station in f.root.sol000.antenna[:]['name']], - self.make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(','), - f.root.sol000.phase000.pol[:], - f.root.sol000.phase000.val[:], - f.root.sol000.phase000.weight[:], - f.root.sol000.amplitude000.val[:], - ) - - - def filter_stations(station_names): - """Generate indices of filtered stations""" - - if self.station == 'debug': - return list(range(len(station_names))) - - station_codes = ( - ('CS',) if self.station == 'dutch' else - ('RS',) if self.station == 'remote' else - ('RS', 'CS', 'ST') # i.e.: if self.station == 'international' - ) - return [ - i for i, station_name in enumerate(station_names) - if any(station_code in station_name for station_code in station_codes) - ] + axes = self.make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(',') + if 'pol' in axes: + return ( + [self.make_utf8(station) for station in f.root.sol000.antenna[:]['name']], + self.make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(','), + f.root.sol000.phase000.pol[:], + f.root.sol000.phase000.val[:], + f.root.sol000.phase000.weight[:], + f.root.sol000.amplitude000.val[:], + ) + elif 'amplitude000' in list(f.root.sol000._v_groups.keys()): + return ( + [self.make_utf8(station) for station in f.root.sol000.antenna[:]['name']], + self.make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(','), + ['XX'], + f.root.sol000.phase000.val[:], + f.root.sol000.phase000.weight[:], + f.root.sol000.amplitude000.val[:], + ) + else: + return ( + [self.make_utf8(station) for station in f.root.sol000.antenna[:]['name']], + self.make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(','), + ['XX'], + f.root.sol000.phase000.val[:], + f.root.sol000.phase000.weight[:], + np.ones(f.root.sol000.phase000.val.shape), + ) def filter_params(station_indices, axes, *parameters): return tuple( @@ -239,7 +264,7 @@ def weighted_vals(vals, weights): station_names1, axes1, phase_pols1, *params1 = extract_data(h5_1) - antenna_selection = functools.partial(filter_params, filter_stations(station_names1), axes1.index('ant')) + antenna_selection = functools.partial(filter_params, self.filter_stations(station_names1), axes1.index('ant')) phase_vals1, phase_weights1, amps1 = antenna_selection(*params1) prep_phase_score = weighted_vals(phase_vals1, phase_weights1) @@ -250,13 +275,15 @@ def weighted_vals(vals, weights): phase_vals2, phase_weights2, amps2 = antenna_selection(*params2) min_length = min(len(phase_pols1), len(phase_pols2)) - assert 0 < min_length < 2 + assert 0 < min_length <= 4 indices = [0] if min_length == 1 else [0, -1] - prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 = filter_params( - indices, axes1.index('pol'), prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 - ) + if 'pol' in axes1: + prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 = filter_params( + indices, axes1.index('pol'), prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 + ) + prep_phase_score = np.subtract(prep_phase_score, weighted_vals(phase_vals2, phase_weights2)) prep_amp_score = np.nan_to_num( np.divide(prep_amp_score, weighted_vals(amps2, phase_weights2)), @@ -270,8 +297,6 @@ def weighted_vals(vals, weights): def solution_stability(self): """ - #TODO: Under development - Get solution stability scores and make figure :return: bestcycle --> best cycle according to solutions @@ -279,12 +304,15 @@ def solution_stability(self): """ # loop over sources to get scores - assert len(self.sources) > 0 for k, source in enumerate(self.sources): + print(source) + sub_h5s = sorted([h5 for h5 in self.h5s if source in h5]) phase_scores = [] amp_scores = [] + + self.print_station_names(sub_h5s[0]) for m, sub_h5 in enumerate(sub_h5s): number = self.get_cycle_num(sub_h5) @@ -307,17 +335,27 @@ def solution_stability(self): self.make_figure(finalphase, finalamp, 'Phase stability', 'Amplitude stability', plotname) + # best cycle based on phase solution stability bestcycle = np.array(finalphase).argmin() self.writer.writerow(['phase', np.nan] + list(finalphase)) self.writer.writerow(['amp', np.nan] + list(finalamp)) - if len(finalphase) > 3: - phase_decrease, phase_quality, amp_quality = self.linreg_slope(finalphase[:4]), self.linreg_slope( - finalphase[-3:]), self.linreg_slope(finalamp[-3:]) - print(phase_decrease, phase_quality, amp_quality) - accept = phase_decrease < 0 and abs(phase_quality) < 0.05 and abs(amp_quality) < 0.05 + # selection based on slope + if len(finalphase) > 4: + phase_decrease, phase_quality, amp_quality = (self.linreg_slope(finalphase[1:4]), + self.linreg_slope(finalphase[-4:]), + self.linreg_slope(finalamp[-4:])) + if not all(v==0 for v in finalamp): + amp_decrease = self.linreg_slope([i for i in finalamp if i != 0][1:3]) + else: + amp_decrease = 0 + accept = (phase_decrease <= 0 + and amp_decrease <= 0 + and phase_quality <= 0.1 + and amp_quality <= 0.1 + and bestcycle > 0) return bestcycle, accept else: return None, False @@ -380,6 +418,8 @@ def get_minmax(inp: Union[str, np.ndarray]): @staticmethod def bilateral_filter(fitsfile=None, sigma_x=1, sigma_y=1, sigma_z=1, general_sigma=None): """ + #TODO: In development + Bilateral filter See: https://www.projectpro.io/recipes/what-is-bilateral-filtering-opencv @@ -460,16 +500,13 @@ def image_stability(self, bilateral_filter: bool = None): minmaxs)).slope # acceptance criteria - if minmax_slope > 0 and rms_slope > 0: - accept = False - elif minmax_slope_start > 0 or rms_slope_start > 0: - accept = False - elif len(self.fitsfiles) < 5: - accept = False - elif bestcycle + 2 < len(rmss): - accept = True - else: - accept = True + accept = ((minmax_slope <= 0 + or rms_slope <= 0) + and minmax_slope_start <= 0 + and rms_slope_start <= 0 + and len(self.fitsfiles) >= 5 + and bestcycle+2 < len(rmss) + and bestcycle > 1) return bestcycle - 1, accept @@ -482,9 +519,10 @@ def parse_args(): """ parser = ArgumentParser(description='Determine selfcal quality') - parser.add_argument('--selfcal_folder', required=True) + parser.add_argument('--fits', nargs='+', help='selfcal fits images', default=[]) + parser.add_argument('--h5', nargs='+', help='h5 solutions', default=[]) parser.add_argument('--bilateral_filter', action='store_true') - parser.add_argument('--station', type=str, default='dutch', choices=['dutch', 'remote', 'international', 'debug']) + parser.add_argument('--station', type=str, help='', default='international', choices=['dutch', 'remote', 'international', 'debug']) return parser.parse_args() @@ -494,19 +532,15 @@ def main(): """ args = parse_args() - sq = SelfcalQuality(args.selfcal_folder, args.station) - if len(sq.h5s) > 0 or len(sq.fitsfiles) > 0: - if len(sq.h5s) > 0: - bestcycle_solutions, accept_solutions = sq.solution_stability() - - print(f"Best cycle according to solutions {bestcycle_solutions}") - print(f"Accept according to solutions {accept_solutions}") - - if len(sq.fitsfiles) > 0: - bestcycle_image, accept_image = sq.image_stability(bilateral_filter=args.bilateral_filter) + sq = SelfcalQuality(args.h5, args.fits, args.station) + if len(sq.h5s) > 1 and len(sq.fitsfiles) > 1: + bestcycle_solutions, accept_solutions = sq.solution_stability() + bestcycle_image, accept_image = sq.image_stability(bilateral_filter=args.bilateral_filter) - print(f"Best cycle according to image {bestcycle_image}") - print(f"Accept according to image {accept_image}") + print(f"Best cycle according to image {bestcycle_image}") + print(f"Accept according to image {accept_image}") + print(f"Best cycle according to solutions {bestcycle_solutions}") + print(f"Accept according to solutions {accept_solutions}") sq.textfile.close() df = pd.read_csv(f'selfcal_performance.csv').set_index('solutions').T