diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/catalogue_helpers/__init__.py b/catalogue_helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/catalogue_helpers/find_sources.py b/catalogue_helpers/find_sources.py deleted file mode 100644 index 910afd49..00000000 --- a/catalogue_helpers/find_sources.py +++ /dev/null @@ -1,515 +0,0 @@ -""" -This script runs pybdsf on a fits file image to extract sources and components. -From the source table, it makes cut out images of all sources (fits and png images) for inspection. - -This has been used for catalogue reduction of the ELAIS-N1 field. -Feel free to adapt to your own needs. (Watch out for hardcoded parameters or paths) - -The output contains the following folders: -- bright_sources --> sources considered to be bright and probably compact -- weak_sources --> fainter sources -- cluster_sources --> sources clustered together (important for source association) - - -""" - -import bdsf -import argparse -from astropy.nddata import Cutout2D -import numpy as np -import astropy.units as u -from astropy.io import fits -from astropy.wcs.utils import skycoord_to_pixel -from astropy.coordinates import SkyCoord -from astropy.wcs import WCS -import matplotlib.pyplot as plt -from matplotlib.colors import SymLogNorm, PowerNorm -from astropy.visualization.wcsaxes import WCSAxes -import matplotlib.path as mpath -import matplotlib.patches as patches -from matplotlib import ticker -import os -from astropy.table import Table -import pyregion -from scipy.cluster.hierarchy import linkage, fcluster -from scipy.spatial import distance -from glob import glob - - -def get_rms(image_data): - """ - from Cyril Tasse/kMS - - :param image_data: image data array - :return: rms (noise measure) - """ - from past.utils import old_div - - maskSup = 1e-7 - m = image_data[np.abs(image_data) > maskSup] - rmsold = np.std(m) - diff = 1e-1 - cut = 3. - med = np.median(m) - for _ in range(10): - ind = np.where(np.abs(m - med) < rmsold * cut)[0] - rms = np.std(m[ind]) - if np.abs(old_div((rms - rmsold), rmsold)) < diff: break - rmsold = rms - print(f'Noise : {str(round(rms * 1000, 4))} {u.mJy / u.beam}') - return rms - - -def get_beamarea(hdu): - """ - Get beam area in pixels - """ - - bmaj = hdu[0].header['BMAJ'] - bmin = hdu[0].header['BMIN'] - - beammaj = bmaj / (2.0 * (2 * np.log(2)) ** 0.5) # Convert to sigma - beammin = bmin / (2.0 * (2 * np.log(2)) ** 0.5) # Convert to sigma - pixarea = abs(hdu[0].header['CDELT1'] * hdu[0].header['CDELT2']) - - beamarea = 2 * np.pi * 1.0 * beammaj * beammin # Note that the volume of a two dimensional gaus$ - beamarea_pix = beamarea / pixarea - - return beamarea_pix - - -def make_cutout(fitsfile=None, pos: tuple = None, size: tuple = (1000, 1000), savefits=None): - """ - Make cutout from your image with this method. - pos (tuple) -> position in pixels - size (tuple) -> size of your image in pixel size, default=(1000,1000) - """ - fts = fits.open(fitsfile) - image_data = fts[0].data - wcs = WCS(fts[0].header, naxis=2) - - while image_data.ndim>2: - image_data = image_data[0] - - out = Cutout2D( - data=image_data, - position=pos, - size=size, - wcs=wcs, - mode='partial' - ) - - wcs = out.wcs - header = wcs.to_header() - - image_data = out.data - # rms = get_rms(image_data) - # hdu = [fits.PrimaryHDU(image_data, header=header)] - - if savefits: - image_data = out.data - image_data = np.expand_dims(np.expand_dims(image_data, axis=0), axis=0) - fits.writeto(savefits, image_data, header, overwrite=True) - - return image_data, header - - -def make_image(fitsfiles, cmap: str = 'RdBu_r', components: str = None): - """ - Image your data with this method. - fitsfiles -> list with fits file - cmap -> choose your preferred cmap - """ - - def fixed_color(shape, saved_attrs): - from pyregion.mpl_helper import properties_func_default - attr_list, attr_dict = saved_attrs - attr_dict["color"] = 'green' - kwargs = properties_func_default(shape, (attr_list, attr_dict)) - return kwargs - - if len(fitsfiles)==1: - fitsfile = fitsfiles[0] - - hdu = fits.open(fitsfile) - image_data = hdu[0].data - while image_data.ndim > 2: - image_data = image_data[0] - header = hdu[0].header - - rms = get_rms(image_data) - vmin = rms - vmax = rms * 9 - - wcs = WCS(header, naxis=2) - - fig = plt.figure(figsize=(7, 10), dpi=200) - plt.subplot(projection=wcs) - WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs) - im = plt.imshow(image_data, origin='lower', cmap=cmap) - im.set_norm(PowerNorm(gamma=0.5, vmin=vmin, vmax=vmax)) - plt.xlabel('Right Ascension (J2000)', size=14) - plt.ylabel('Declination (J2000)', size=14) - plt.tick_params(axis='both', which='major', labelsize=12) - - - if components is not None: - r = pyregion.open(components).as_imagecoord(header=hdu[0].header) - patch_list, artist_list = r.get_mpl_patches_texts(fixed_color) - - # fig.add_axes(ax) - for patch in patch_list: - plt.gcf().gca().add_patch(patch) - for artist in artist_list: - plt.gca().add_artist(artist) - - orientation = 'horizontal' - ax_cbar1 = fig.add_axes([0.22, 0.15, 0.73, 0.02]) - cb = plt.colorbar(im, cax=ax_cbar1, orientation=orientation) - cb.set_label('Surface brightness [mJy/beam]', size=16) - cb.ax.tick_params(labelsize=16) - - cb.outline.set_visible(False) - - # Extend colorbar - bot = -0.05 - top = 1.05 - - # Upper bound - xy = np.array([[0, 1], [0, top], [1, top], [1, 1]]) - if orientation == "horizontal": - xy = xy[:, ::-1] - - Path = mpath.Path - - # Make Bezier curve - curve = [ - Path.MOVETO, - Path.CURVE4, - Path.CURVE4, - Path.CURVE4, - ] - - color = cb.cmap(cb.norm(cb._values[-1])) - patch = patches.PathPatch( - mpath.Path(xy, curve), - facecolor=color, - linewidth=0, - antialiased=False, - transform=cb.ax.transAxes, - clip_on=False, - ) - cb.ax.add_patch(patch) - - # Lower bound - xy = np.array([[0, 0], [0, bot], [1, bot], [1, 0]]) - if orientation == "horizontal": - xy = xy[:, ::-1] - - color = cb.cmap(cb.norm(cb._values[0])) - patch = patches.PathPatch( - mpath.Path(xy, curve), - facecolor=color, - linewidth=0, - antialiased=False, - transform=cb.ax.transAxes, - clip_on=False, - ) - cb.ax.add_patch(patch) - - # Outline - xy = np.array( - [[0, 0], [0, bot], [1, bot], [1, 0], [1, 1], [1, top], [0, top], [0, 1], [0, 0]] - ) - if orientation == "horizontal": - xy = xy[:, ::-1] - - Path = mpath.Path - - curve = [ - Path.MOVETO, - Path.CURVE4, - Path.CURVE4, - Path.CURVE4, - Path.LINETO, - Path.CURVE4, - Path.CURVE4, - Path.CURVE4, - Path.LINETO, - ] - path = mpath.Path(xy, curve, closed=True) - - patch = patches.PathPatch( - path, facecolor="None", lw=1, transform=cb.ax.transAxes, clip_on=False - ) - cb.ax.add_patch(patch) - tick_locator = ticker.MaxNLocator(nbins=3) - cb.locator = tick_locator - cb.update_ticks() - - fig.tight_layout(pad=1.0) - plt.grid(False) - plt.grid('off') - plt.savefig(fitsfile.replace('.fits', '.png'), dpi=250, bbox_inches='tight') - plt.close() - - hdu.close() - - else: - - for n, fitsfile in enumerate(fitsfiles): - - hdu = fits.open(fitsfile) - header = hdu[0].header - - - if n==0: - cdelt = abs(header['CDELT2']) - w = WCS(header, naxis=2) - - fig, axs = plt.subplots(2, 2, - figsize=(10, 8), - subplot_kw={'projection': w}) - - imdat = hdu[0].data - while imdat.ndim > 2: - imdat = imdat[0] - or_shape = imdat.shape - skycenter = w.pixel_to_world(header['NAXIS1']//2, header['NAXIS2']//2) - rms = get_rms(imdat) - ax = plt.subplot(220 + n+1, projection=w) - - if components is not None: - r = pyregion.open(components).as_imagecoord(header=header) - patch_list, artist_list = r.get_mpl_patches_texts(fixed_color) - - # fig.add_axes(ax) - for patch in patch_list: - ax.add_patch(patch) - for artist in artist_list: - ax.add_artist(artist) - - else: - pixfact = cdelt/abs(header['CDELT2']) - shape = np.array(or_shape) * pixfact - center_sky = SkyCoord(f'{skycenter.ra.value}deg', f'{skycenter.dec.value}deg', frame='icrs') - - w = WCS(header, naxis=2) - pix_coord = skycoord_to_pixel(center_sky, w, 0, 'all') - imdat, h = make_cutout(fitsfile=fitsfile, - pos=tuple([int(p) for p in pix_coord]), - size=tuple([int(p) for p in shape])) - w = WCS(h, naxis=2) - ax = plt.subplot(220 + n+1, projection=w) - - while imdat.ndim > 2: - imdat = imdat[0] - - imdat *= 1000 - - rms = get_rms(imdat) - vmin = rms - vmax = rms * 9 - - - im = ax.imshow(imdat, origin='lower', cmap=cmap, norm=PowerNorm(gamma=0.5, vmin=vmin, vmax=vmax)) - ax.set_xlabel('Right Ascension (J2000)', size=12) - ax.set_ylabel('Declination (J2000)', size=12) - if n!=0: - ax.set_title(fitsfile.split('/')[-2].replace('_', ' ')) - - cb = fig.colorbar(im, ax=ax, orientation='horizontal', shrink=0.65, pad=0.15) - cb.set_label('Surface brightness [mJy/beam]', size=12) - cb.ax.tick_params(labelsize=12) - - fig.tight_layout(pad=1.0) - plt.grid(False) - plt.grid('off') - plt.savefig(fitsfiles[0].replace('.fits', '.png'), dpi=250, bbox_inches='tight') - - -def run_pybdsf(fitsfile, rmsbox): - """ - Run pybdsf - - :param fitsfile: fits file - :param rmsbox: rms box first parameter - - :return: source catalogue - """ - - prefix = fitsfile.replace('.fits', '') - img = bdsf.process_image(fitsfile, - thresh_isl=3, - thresh_pix=5, - atrous_do=True, - rms_box=(int(rmsbox), int(rmsbox // 8)), - rms_box_bright=(int(rmsbox//3), int(rmsbox//12)), - adaptive_rms_box=True, - group_tol=10.0) # , rms_map=True, rms_box = (160,40)) - - img.write_catalog(clobber=True, outfile=prefix + '_source_catalog.fits', format='fits', catalog_type='srl') - img.write_catalog(clobber=True, outfile=prefix + '_gaussian_catalog.fits', format='fits', catalog_type='gaul') - for type in ['island_mask', 'gaus_model', 'gaus_resid', 'mean', 'rms']: - img.export_image(clobber=True, img_type=type, outfile=prefix + f'_{type}.fits') - return prefix + '_source_catalog.fits' - - -def get_pix_coord(table): - """ - Get pixel coordiantes from table RA/DEC - :param table: fits table - :return: pixel coordinates - """ - - f = fits.open(table) - t = f[1].data - # res = t[t['S_Code'] != 'S'] - pos = list(zip(t['RA'], t['DEC'], t['Source_id'])) - pos = [[SkyCoord(f'{c[0]}deg', f'{c[1]}deg', frame='icrs'), c[2]] for c in pos] - fts = fits.open(table.replace("_source_catalog", "")) - w = WCS(fts[0].header, naxis=2) - pix_coord = [([int(c) for c in skycoord_to_pixel(sky[0], w, 0, 'all')], sky[1]) for sky in pos] - return pix_coord - - -def make_point_file(t): - """ - Make ds9 file with ID in it - """ - - header = """# Region file format: DS9 version 4.1 -global color=green dashlist=8 3 width=1 font="helvetica 10 normal roman" select=1 highlite=1 dash=0 fixed=0 edit=1 move=1 delete=1 include=1 source=1 -fk5 -""" - t = Table.read(t, format='fits') - file = open('components.reg', 'w') - file.write(header) - for n, c in enumerate(zip(t['RA'], t['DEC'])): - file.write(f'\n# text({c[0]},{c[1]}) text=' + '{' + f'{t["Source_id"][n]}' + '}') - return - - -def get_table_index(t, source_id): - return int(np.argwhere(t['Source_id'] == source_id).squeeze()) - - -def get_clusters_ra_dec(t, deg_dist=0.003): - """ - Get clusters of sources based on euclidean distance - """ - ra_dec = np.stack((list(t['RA']),list(t['DEC'])),axis=1) - Z = linkage(ra_dec, method='complete', metric='euclidean') - return fcluster(Z, deg_dist, criterion='distance') - -def get_clusters_pix(pixcoor, pix_dist=100): - """ - Get clusters of sources based on euclidean distance - """ - pixpos = np.stack((pixcoor[:,0],pixcoor[:,1]),axis=1) - Z = linkage(pixpos, method='complete', metric='euclidean') - return fcluster(Z, pix_dist, criterion='distance') - - -def cluster_idx(clusters, idx): - return np.argwhere(clusters==clusters[idx]).squeeze(axis=1) - - -def max_dist(coordinates): - """ - Get the longest distance between coordinates - - :param coordinates: indices from table - """ - return np.max(distance.cdist(coordinates, coordinates, 'euclidean')) - - -def parse_args(): - """ - Parse input arguments - """ - - parser = argparse.ArgumentParser(description='Source detection') - parser.add_argument('--rmsbox', type=int, help='rms box pybdsf', default=120) - parser.add_argument('--no_pybdsf', action='store_true', help='Skip pybdsf') - parser.add_argument('--comparison_plots', nargs='+', help='Add fits files to compare with, ' - 'with same field coverage', default=[]) - parser.add_argument('fits', nargs='+', help='fits files') - return parser.parse_args() - - -def main(): - """ - Main function - """ - - args = parse_args() - for m, fts in enumerate(args.fits): - if not args.no_pybdsf: - tbl = run_pybdsf(fts, args.rmsbox) - else: - tbl = fts.replace('.fits', '') + '_source_catalog.fits' - - # make ds9 region file with sources in it - make_point_file(tbl) - # loop through resolved sources and make images - coord = get_pix_coord(tbl) - - T = Table.read(tbl, format='fits') - T['Peak_flux_min'] = T['Peak_flux'] - T['E_Peak_flux'] - T['Total_flux_min'] = T['Total_flux'] - T['E_Total_flux'] - T['Total_flux_over_peak_flux'] = T['Total_flux'] / T['Peak_flux'] - - f = fits.open(fts) - pixscale = np.sqrt(abs(f[0].header['CDELT2']*f[0].header['CDELT1'])) - beamarea = get_beamarea(f) - f.close() - - clusters_small = get_clusters_ra_dec(T, pixscale*100) - clusters_large = get_clusters_ra_dec(T, max(0.01, pixscale*100)) - - os.system('mkdir -p bright_sources') - os.system('mkdir -p weak_sources') - os.system('mkdir -p cluster_sources') - - to_ignore = [] - - for c, n in coord: - table_idx = get_table_index(T, n) - if table_idx in to_ignore: - continue - - rms = T[T['Source_id'] == n]['Isl_rms'][0] - cluster_indices_large = cluster_idx(clusters_large, table_idx) - clusters_indices_small = cluster_idx(clusters_small, table_idx) - - if len(cluster_indices_large) > 3: - cluster_indices = cluster_indices_large - else: - cluster_indices = clusters_indices_small - cluster_indices = [idx for idx in cluster_indices if idx not in to_ignore] - - if len(cluster_indices) > 1: - pix_coord = np.array([p[0] for p in coord])[cluster_indices] - imsize = max(int(max_dist(pix_coord)*3), 150) - idxs = '-'.join([str(p) for p in cluster_indices]) - make_cutout(fitsfile=fts, pos=tuple(c), size=(imsize, imsize), savefits=f'cluster_sources/source_{m}_{idxs}.fits') - make_image([f'cluster_sources/source_{m}_{idxs}.fits']+args.comparison_plots, 'RdBu_r', 'components.reg') - for i in cluster_indices: - to_ignore.append(i) - - elif T[T['Source_id'] == n]['Peak_flux_min'][0] < 1.5*rms or \ - T[T['Source_id'] == n]['Peak_flux'][0] < 5.5*rms or \ - T[T['Source_id'] == n]['Total_flux'] < 7*rms: - - make_cutout(fitsfile=fts, pos=tuple(c), size=(300, 300), savefits=f'weak_sources/source_{m}_{n}.fits') - make_image([f'weak_sources/source_{m}_{n}.fits']+args.comparison_plots, 'RdBu_r', 'components.reg') - - else: - make_cutout(fitsfile=fts, pos=tuple(c), size=(300, 300), savefits=f'bright_sources/source_{m}_{n}.fits') - make_image([f'bright_sources/source_{m}_{n}.fits']+args.comparison_plots, 'RdBu_r', 'components.reg') - - - -if __name__ == '__main__': - main() diff --git a/catalogue_helpers/run_pybdsf.py b/catalogue_helpers/pybdsf_experiments/run_pybdsf.py similarity index 100% rename from catalogue_helpers/run_pybdsf.py rename to catalogue_helpers/pybdsf_experiments/run_pybdsf.py diff --git a/ds9_helpers/__init__.py b/ds9_helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ds9_helpers/select_facets.py b/ds9_helpers/select_facets.py deleted file mode 100644 index 2458f505..00000000 --- a/ds9_helpers/select_facets.py +++ /dev/null @@ -1,132 +0,0 @@ -import pandas as pd -import numpy as np -from math import radians, cos, sin, sqrt, atan2 -from scipy.spatial import Voronoi, voronoi_plot_2d -import matplotlib.cm as cm -import matplotlib.pyplot as plt - - -def angular_distance(ra1_deg, dec1_deg, ra2_deg, dec2_deg): - """ - Function to calculate angular distance between two points with RA/DEC - - :param ra1_deg: RA first source - :param dec1_deg: DEC first source - :param ra2_deg: RA second source - :param dec2_deg: DEC second source - - :return: distance in degrees - - """ - # Convert degrees to radians - ra1 = radians(ra1_deg) - dec1 = radians(dec1_deg) - ra2 = radians(ra2_deg) - dec2 = radians(dec2_deg) - - # Haversine-like formula for angular distance in degrees - delta_ra = ra2 - ra1 - delta_dec = dec2 - dec1 - - a = sin(delta_dec / 2) ** 2 + cos(dec1) * cos(dec2) * sin(delta_ra / 2) ** 2 - c = 2 * atan2(sqrt(a), sqrt(1 - a)) - - # Convert back to degrees - distance_deg = np.degrees(c) - return distance_deg - - -def remove_close_pairs(csv_file, dist_offset=0.1): - """ - Remove close pairs in CSV and sort by scalarphasediff score - - Args: - csv_file: CSV file with sources RA/DEC and spd_score - dist_offset: min offset in degrees - - Returns: - cleaned df - """ - - # Load the CSV file - df = pd.read_csv(csv_file).sort_values('spd_score') - - # Assuming the CSV file has columns 'RA' and 'DEC' in degrees - ra = df['RA'].values - dec = df['DEC'].values - - # List to store pairs and their distances - pairs = [] - - # Calculate pairwise distances - for i in range(len(ra)): - for j in range(i + 1, len(ra)): - dist = angular_distance(ra[i], dec[i], ra[j], dec[j]) - if dist < dist_offset: - pairs.append((i, j, dist)) - - # Convert to a DataFrame to display - duplicate_df = pd.DataFrame(pairs, columns=['Index 1', 'Index 2', 'Distance (degrees)']) - - for idx in duplicate_df['Index 2'].values[::-1]: - df.drop(idx, inplace=True) - - return df - - -def plot_voronoi(csv_path, ra_column='RA', dec_column='DEC', score_column='spd_score'): - """ - Plots a Voronoi diagram where the points (RA/DEC) are colored based on the values in the score_column. - - Parameters: - csv_path (str): Path to the CSV file. - ra_column (str): Column name for RA values. - dec_column (str): Column name for DEC values. - score_column (str): Column name for the score values used for coloring the points. - """ - # Load the CSV data - data = pd.read_csv(csv_path).sort_values(score_column) - - # Extract RA, DEC, and spd_scores columns - ra = data[ra_column].values - dec = data[dec_column].values - spd_scores = data[score_column].values - - # Combine RA and DEC into coordinate pairs - points = np.column_stack((ra, dec)) - - # Compute Voronoi tessellation - vor = Voronoi(points) - - # Create a plot - fig, ax = plt.subplots(figsize=(10, 8)) - - # Plot the Voronoi diagram without filling the regions - voronoi_plot_2d(vor, ax=ax, show_vertices=False, line_colors='black', line_width=2, line_alpha=0.6) - - # Normalize spd_scores for coloring - norm = plt.Normalize(vmin=0, vmax=2) - cmap = cm.viridis - - # Scatter plot the points, colored by spd_scores - sc = ax.scatter(ra[1:], dec[1:], c=spd_scores[1:], cmap=cmap, norm=norm, edgecolor='black', s=200, zorder=2) - - # Highlight the first point with a red star - ax.scatter(ra[0], dec[0], color='red', marker='*', s=600, zorder=3, edgecolor='black') - - - ax.tick_params(labelsize=16) - - # Add a colorbar - # Add a colorbar and set font sizes - cbar = plt.colorbar(sc) - cbar.set_label('$\hat{\sigma}_{c}$ (rad)', fontsize=16) # Set the font size for the colorbar label - cbar.ax.tick_params(labelsize=16) # Set the font size for colorbar ticks - plt.gca().invert_xaxis() - - # Set labels and title - ax.set_xlabel('RA (degrees)', fontsize=16) - ax.set_ylabel('DEC (degrees)', fontsize=16) - - plt.tight_layout() - plt.savefig('voronoi_calibrators.png', dpi=150) \ No newline at end of file diff --git a/ds9_helpers/split_polygon_facets.py b/ds9_helpers/split_polygon_facets.py index 065f66ae..34f5435f 100644 --- a/ds9_helpers/split_polygon_facets.py +++ b/ds9_helpers/split_polygon_facets.py @@ -143,9 +143,9 @@ def main(): H.close() #TODO: dangerous! Convert to degrees - if np.all(np.abs(dirs)<2*np.pi): - dirs %= (2*np.pi) - dirs *= 360/(2*np.pi) + if np.all(np.abs(dirs) < 2 * np.pi): + # Converting radians to degrees + dirs = np.degrees(dirs) f = open('polygon_info.csv', 'w') writer = csv.writer(f) diff --git a/fits_helpers/__init__.py b/fits_helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fits_helpers/crop_nan_boundaries.py b/fits_helpers/crop_nan_boundaries.py index 83b50f65..9fc8122b 100644 --- a/fits_helpers/crop_nan_boundaries.py +++ b/fits_helpers/crop_nan_boundaries.py @@ -43,14 +43,18 @@ def parse_args(): :return: parsed arguments """ parser = ArgumentParser(description='Crop fits file with nan boundaries') - parser.add_argument('--fits_input', help='fits input file', required=True, type=str) - parser.add_argument('--fits_output', help='fits output file', required=True, type=str) + parser.add_argument('fits_in', help='fits input file', type=str) + parser.add_argument('--output_name', help='fits output file', type=str) return parser.parse_args() def main(): """ Main function""" args = parse_args() - crop_nan_boundaries(args.fits_input, args.fits_output) + if args.output_name is None: + outname = args.fits_in + else: + outname = args.output_name + crop_nan_boundaries(args.fits_in, outname) if __name__ == '__main__': main() \ No newline at end of file diff --git a/fits_helpers/cropy_2048.py b/fits_helpers/cropy_2048.py index 7e5c5ab1..4c50f7d3 100644 --- a/fits_helpers/cropy_2048.py +++ b/fits_helpers/cropy_2048.py @@ -1,5 +1,10 @@ +""" +Crop image to 2048x2048 (for neural network) +""" + from astropy.io import fits import numpy as np +from argparse import ArgumentParser def crop_fits_image(input_filename, output_filename, center=None): @@ -55,3 +60,22 @@ def crop_fits_image(input_filename, output_filename, center=None): # Write the cropped image to the output file hdu = fits.PrimaryHDU(data=np.array([[cropped_data]]), header=header) hdu.writeto(output_filename, overwrite=True) + + +def parse_args(): + """ + Command line argument parser + :return: parsed arguments + """ + parser = ArgumentParser(description='Crop image to 2048x2048') + parser.add_argument('--fits_input', help='fits input file', required=True, type=str) + parser.add_argument('--fits_output', help='fits output file', required=True, type=str) + return parser.parse_args() + +def main(): + """ Main function""" + args = parse_args() + crop_fits_image(args.fits_input, args.fits_output) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/fits_helpers/cut_fitsfile.py b/fits_helpers/cut_fits_with_region.py similarity index 81% rename from fits_helpers/cut_fitsfile.py rename to fits_helpers/cut_fits_with_region.py index c8005028..0c19d650 100644 --- a/fits_helpers/cut_fitsfile.py +++ b/fits_helpers/cut_fits_with_region.py @@ -12,8 +12,8 @@ def parse_args(): :return: parsed arguments """ parser = ArgumentParser(description='Cut fits file with region file') - parser.add_argument('--fits_input', help='fits input file', required=True, type=str) - parser.add_argument('--fits_output', help='fits output file', required=True, type=str) + parser.add_argument('fits_input', help='fits input file', type=str) + parser.add_argument('--output_name', help='fits output file', type=str) parser.add_argument('--region', help='region file', required=True, type=str) return parser.parse_args() @@ -24,7 +24,9 @@ def main(): fitsfile = args.fits_input regionfile = args.region - outputfits = args.fits_output + outputfits = args.output_name + if outputfits is None: + outputfits = fitsfile hdu = fits.open(fitsfile) diff --git a/fits_helpers/get_beam_rms_facet.py b/fits_helpers/get_beam_rms_facet.py deleted file mode 100644 index b316b29f..00000000 --- a/fits_helpers/get_beam_rms_facet.py +++ /dev/null @@ -1,61 +0,0 @@ -from glob import glob -import numpy as np -from astropy.io import fits -import astropy.units as u -from past.utils import old_div - -def rms(image_data): - """ - from Cyril Tasse/kMS - - :param image_data: image data array - :return: rms (noise measure) - """ - - maskSup = 1e-7 - m = image_data[np.abs(image_data)>maskSup] - rmsold = np.std(m) - diff = 1e-1 - cut = 3. - med = np.median(m) - for _ in range(10): - ind = np.where(np.abs(m - med) < rmsold*cut)[0] - rms = np.std(m[ind]) - if np.abs(old_div((rms-rmsold), rmsold)) < diff: break - rmsold = rms - print(f'Noise : {str(round(rms * 1000, 4))} {u.mJy/u.beam}') - return rms - -print("0.3''") -for fts in sorted(glob('/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/imaging/split_facets2/final_img_0.3_rms_old/facet_*.fits')): - print(fts) - f = fits.open(fts)[0] - print(f"{f.header['BMIN']*3600}'' x {f.header['BMAJ']*3600}''") - data = f.data - pixelsize = f.header['CDELT1'] ** 2 - print(f'Size: {len(data[data == data]) * pixelsize} degree^2') - rms(data) - print() - -print("0.6''") -for fts in sorted(glob('/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/imaging/split_facets2/final_img_0.6_rms_old/facet_*.fits')): - print(fts) - f = fits.open(fts)[0] - print(f"{f.header['BMIN']*3600}'' x {f.header['BMAJ']*3600}''") - data = f.data - pixelsize = f.header['CDELT1'] ** 2 - print(f'Size: {len(data[data == data]) * pixelsize} degree^2') - rms(data) - print() - -print("1.2''") -for fts in sorted(glob('/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/imaging/split_facets2/final_img_1.2_rms_old/facet_?.fits') + - glob('/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/imaging/split_facets2/final_img_1.2_rms_old/facet_??.fits')): - print(fts) - f = fits.open(fts)[0] - print(f"{round(f.header['BMIN']*3600, 2)}'' x {round(f.header['BMAJ']*3600,2)}''") - pixelsize = f.header['CDELT1'] ** 2 - data = f.data - print(f'Size: {len(data[data == data]) * pixelsize} degree^2') - rms(data) - print() diff --git a/fits_helpers/make_cutouts.py b/fits_helpers/image_utils/make_cutouts.py similarity index 100% rename from fits_helpers/make_cutouts.py rename to fits_helpers/image_utils/make_cutouts.py diff --git a/fits_helpers/power_spec_simple.py b/fits_helpers/image_utils/power_spec_simple.py similarity index 100% rename from fits_helpers/power_spec_simple.py rename to fits_helpers/image_utils/power_spec_simple.py diff --git a/fits_helpers/plot_baseline_track.py b/fits_helpers/plot_baseline_track.py deleted file mode 100644 index e9a9a526..00000000 --- a/fits_helpers/plot_baseline_track.py +++ /dev/null @@ -1,109 +0,0 @@ -from casacore.tables import table -import sys -import matplotlib.pyplot as plt -from glob import glob -import os - -# Set the MPLCONFIGDIR environment variable -os.system('mkdir -p ~/matplotlib_cache') -os.environ['MPLCONFIGDIR'] = os.path.expanduser('~/matplotlib_cache') - - -def get_station_id(ms): - """ - Get station with corresponding id number - - :param: - - ms: measurement set - - :return: - - antenna names, IDs - """ - - t = table(ms+'::ANTENNA', ack=False) - ants = t.getcol("NAME") - t.close() - - t = table(ms+'::FEED', ack=False) - ids = t.getcol("ANTENNA_ID") - t.close() - - return ants, ids - - -def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, baseline='0-1', UV=True, saveas=None): - """ - Plot baseline track - - :param: - - t_final_name: table with final name - - t_input_names: tables to compare with - - mappingfiles: baseline mapping files - """ - - if len(t_input_names) > 4: - sys.exit("ERROR: Can just plot 4 inputs") - - colors = ['red', 'green', 'yellow', 'black'] - - if not UV: - print("MAKE UW PLOT") - - ant1, ant2 = baseline.split('-') - - for n, t_input_name in enumerate(t_input_names): - - ref_stats, ref_ids = get_station_id(t_final_name) - new_stats, new_ids = get_station_id(t_input_name) - - id_map = dict(zip([ref_stats.index(a) for a in new_stats], new_ids)) - - with table(t_final_name, ack=False) as f: - fsub = f.query(f'ANTENNA1={ant1} AND ANTENNA2={ant2} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW') - uvw1 = fsub.getcol("UVW") - - with table(t_input_name, ack=False) as f: - fsub = f.query(f'ANTENNA1={id_map[int(ant1)]} AND ANTENNA2={id_map[int(ant2)]} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW') - uvw2 = fsub.getcol("UVW") - - # Scatter plot for uvw1 - if n == 0: - lbl = 'Final MS' - else: - lbl = None - - plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.2, s=100, marker='o') - - # Scatter plot for uvw2 - plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'Dataset {n}', color=colors[n], edgecolor='black', alpha=0.7, s=40, marker='*') - - - # Adding labels and title - plt.xlabel("U (m)", fontsize=14) - plt.ylabel("V (m)" if UV else "W (m)", fontsize=14) - - # Adding grid - plt.grid(True, linestyle='--', alpha=0.6) - - # Adding legend - plt.legend(fontsize=12) - - plt.xlim(561260, 563520) - plt.ylim(192782, 194622) - - plt.tight_layout() - - if saveas is None: - plt.show() - else: - plt.savefig(saveas, dpi=150) - plt.close() - - -#####TEST###### -plot_baseline_track('test_1.ms', sorted(glob('a*.ms')), '0-71', saveas='test_1.png') -plot_baseline_track('test_2.ms', sorted(glob('a*.ms')), '0-71', saveas='test_2.png') -plot_baseline_track('test_3.ms', sorted(glob('a*.ms')), '0-71', saveas='test_3.png') -plot_baseline_track('test_4.ms', sorted(glob('a*.ms')), '0-71', saveas='test_4.png') -plot_baseline_track('test_6.ms', sorted(glob('a*.ms')), '0-71', saveas='test_6.png') -plot_baseline_track('test_8.ms', sorted(glob('a*.ms')), '0-71', saveas='test_8.png') \ No newline at end of file diff --git a/h5_helpers/__init__.py b/h5_helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/h5_helpers/close_h5.py b/h5_helpers/close_h5.py index 230ee30f..f5be0c91 100644 --- a/h5_helpers/close_h5.py +++ b/h5_helpers/close_h5.py @@ -24,3 +24,7 @@ def force_close(h5): h.close() return sys.stderr.write(h5 + ' not found\n') + + +if __name__ == '__main__': + force_close_all() \ No newline at end of file diff --git a/h5_helpers/h5_filter.py b/h5_helpers/h5_filter.py deleted file mode 100644 index 324fbd36..00000000 --- a/h5_helpers/h5_filter.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -This script is meant to filter directions outside of a given circle. -This can be useful for excluding directions at the boundaries of a field because the solutions are less good there. -""" - -__author__ = "Jurjen de Jong (jurjendejong@strw.leidenuniv.nl)" - -import tables -from astropy.wcs import WCS -from astropy.io import fits -from argparse import ArgumentParser, ArgumentTypeError -from math import pi, cos, sin, acos -from losoto.h5parm import h5parm -from numpy import ones, zeros, unique -import os -from glob import glob -import re - - -def str2bool(v): - v = str(v) - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise ArgumentTypeError('Boolean value expected.') - - -def degree_to_radian(inp): - """degree to radian""" - return float(inp) / 360 * pi * 2 - - -def radian_to_degree(inp): - """radion to degree""" - return float(inp) * 360 / (pi * 2) - - -def angular_distance(p1, p2): - """angular distance for points in ra and dec in degrees""" - if p1[0] > 2 * pi: - p1 = [degree_to_radian(p) for p in p1] - p2 = [degree_to_radian(p) for p in p2] - return radian_to_degree(acos(sin(p1[1]) * sin(p2[1]) + cos(p1[1]) * cos(p2[1]) * cos(p1[0] - p2[0]))) - - -def remove_numbers(inp): - return "".join(re.findall("[a-zA-z]+", inp)) - - -def create_new_soltab(h5_in_name, h5_out_name, directions, sources): - """ - Create a new dataset in the h5 table - :param filename: name of ourput file - :param solset: solution set name - :param soltab: solution table name - :param dirs: directions to include - """ - - h5_in = h5parm(h5_in_name, readonly=True) - h5_out = h5parm(h5_out_name, readonly=False) - for ss in h5_in.getSolsetNames(): - - if ss in h5_out.getSolsetNames(): - solsetout = h5_out.getSolset(ss) - else: - solsetout = h5_out.makeSolset(ss) - - current_sources = [source[0].decode('UTF-8') for source in solsetout.obj.source[:]] - new_sources = [source for source in sources if source[0].decode('UTF-8') not in current_sources] - new_sources = [(bytes('Dir' + str(n).zfill(2), 'utf-8'), ns[1]) for n, ns in enumerate(new_sources)] - if len(new_sources) > 0: - solsetout.obj.source.append(new_sources) - - for st in h5_in.getSolset(ss).getSoltabNames(): - print('Filter {solset}/{soltab} from {h5_in} into {h5_out}'.format(solset=ss, soltab=st, - h5_in=h5_in_name.split('/')[-1], - h5_out=h5_out_name.split('/')[-1])) - - solutiontable = h5_in.getSolset(ss).getSoltab(st) - axes = solutiontable.getValues()[1] - values_in = solutiontable.getValues()[0] - indexes = [list(axes['dir']).index(dir.decode('UTF-8')) for dir in directions] - axes['dir'] = [ns[0] for ns in new_sources] - dir_index = solutiontable.getAxesNames().index('dir') - shape = list(values_in.shape) - shape[dir_index] = len(directions) - values_new = zeros(shape) - - for idx_new, idx_old in enumerate(indexes): - if dir_index == 0: - values_new[idx_new, ...] += values_in[idx_old, ...] - elif dir_index == 1: - values_new[:, idx_new, ...] += values_in[:, idx_old, ...] - elif dir_index == 2: - values_new[:, :, idx_new, ...] += values_in[:, :, idx_old, ...] - elif dir_index == 3: - values_new[:, :, :, idx_new, ...] += values_in[:, :, :, idx_old, ...] - elif dir_index == 4: - values_new[:, :, :, :, idx_new, ...] += values_in[:, :, :, :, idx_old, ...] - - print('New number of sources {num}'.format(num=len(sources))) - print('Filtered output shape {shape}'.format(shape=values_new.shape)) - - weights = ones(values_new.shape) - solsetout.makeSoltab(remove_numbers(st), axesNames=list(axes.keys()), axesVals=list(axes.values()), - vals=values_new, - weights=weights) - - h5_in.close() - h5_out.close() - - -def parse_args(): - """ - Command line argument parser - - :return: parsed arguments - """ - - parser = ArgumentParser() - parser.add_argument('-f', '--fits', type=str, help='fitsfile name') - parser.add_argument('-ac', '--angular_cutoff', type=float, default=None, - help='angular distances higher than this value from the center will be excluded from the box selection') - parser.add_argument('-in', '--inside', type=str2bool, default=False, - help='keep directions inside the angular cutoff') - parser.add_argument('-h5out', '--h5_file_out', type=str, help='h5 output name') - parser.add_argument('-h5in', '--h5_file_in', type=str, help='h5 input name (to filter)') - return parser.parse_args() - - -def main(): - """Main function""" - - args = parse_args() - - # clean up files - if args.h5_file_out.split('/')[-1] in [f.split('/')[-1] for f in glob(args.h5_file_out)]: - os.system('rm -rf {file}'.format(file=args.h5_file_out)) - - # get header from fits file - hdu = fits.open(args.fits)[0] - header = WCS(hdu.header, naxis=2).to_header() - # get center of field - center = (degree_to_radian(header['CRVAL1']), degree_to_radian(header['CRVAL2'])) - - # return list of directions that have to be included - H = tables.open_file(args.h5_file_in) - sources = [] - directions = [] - for dir in H.root.sol000.source[:]: - print(angular_distance(center, dir[1])) - if args.inside and angular_distance(center, dir[1]) < args.angular_cutoff: - print('Keep {dir}'.format(dir=dir)) - directions.append(dir[0]) - sources.append(dir) - elif not args.inside and angular_distance(center, dir[1]) >= args.angular_cutoff: - print('Keep {dir}'.format(dir=dir)) - directions.append(dir[0]) - sources.append(dir) - else: - print('Remove {dir}'.format(dir=dir)) - H.close() - - create_new_soltab(args.h5_file_in, args.h5_file_out, directions, sources) - - -if __name__ == '__main__': - main() diff --git a/h5_helpers/plot_h5.py b/h5_helpers/hardcoded_plotting_stuff/plot_h5.py similarity index 100% rename from h5_helpers/plot_h5.py rename to h5_helpers/hardcoded_plotting_stuff/plot_h5.py diff --git a/h5_helpers/split_h5.py b/h5_helpers/split_h5.py index ec9d3b60..4d76d12f 100644 --- a/h5_helpers/split_h5.py +++ b/h5_helpers/split_h5.py @@ -1,3 +1,7 @@ +""" +Split out multi-dir h5 +""" + import tables import numpy as np import shutil diff --git a/h5_merger.py b/h5_merger.py index d827f5e3..696a1b61 100644 --- a/h5_merger.py +++ b/h5_merger.py @@ -12,25 +12,35 @@ __author__ = "Jurjen de Jong (jurjendejong@strw.leidenuniv.nl)" -from casacore import tables as ct -from collections import OrderedDict -from glob import glob -from losoto.h5parm import h5parm -from losoto.lib_operations import reorderAxes -from numpy import zeros, ones, round, unique, array_equal, append, where, isfinite, complex128, expand_dims, \ - pi, array, all, exp, angle, sort, sum, finfo, take, diff, equal, take, transpose, cumsum, insert, abs, asarray, newaxis, argmin, cos, sin, float32, memmap -import math +# Standard library imports import os -import psutil -import re -from scipy.interpolate import interp1d import sys -import tables +import re +import math import warnings +from glob import glob from argparse import ArgumentParser +from collections import OrderedDict +import psutil +import ast + +# Third-party imports +import numpy as np +from numpy import ( + zeros, ones, round, unique, array_equal, append, where, isfinite, complex128, + expand_dims, ndarray, pi, array, all, exp, angle, sort, sum, finfo, take, + diff, equal, transpose, cumsum, insert, abs, asarray, newaxis, argmin, + cos, sin, float32, memmap +) +from scipy.interpolate import interp1d +import tables from astropy.coordinates import SkyCoord from astropy import units as u -import ast +from casacore import tables as ct + +# Project-specific imports +from losoto.h5parm import h5parm +from losoto.lib_operations import reorderAxes warnings.filterwarnings('ignore') @@ -139,27 +149,34 @@ def overwrite_table(T, solset, table, values, title=None): :param title: title of new table """ + # Check if the file is opened with 'tables' try: T.root - except: - sys.exit('ERROR: Create table failed. Given table is not opened with the package "tables" (https://pypi.org/project/tables/).') + except AttributeError: + sys.exit('ERROR: Given table is not opened with the "tables" package (https://pypi.org/project/tables/).') - if 'sol' not in solset: - print('WARNING: Usual input have sol*** as solset name.') + # Warning if solset does not follow the usual naming convention + if not solset.startswith('sol'): + print('WARNING: Solution set name should start with "sol".') + # Get the solution set and remove the existing table ss = T.root._f_get_child(solset) - ss._f_get_child(table)._f_remove() + if hasattr(ss, table): + ss._f_get_child(table)._f_remove() + + # Handle specific cases for source and antenna tables if table == 'source': values = array(values, dtype=[('name', 'S128'), ('dir', ' 0: print('Ignore MS for time and freq axis, as --h5_time_freq is given.') print('Take the time and freq from the following h5 solution file:\n' + h5_time_freq) - T = tables.open_file(h5_time_freq) - self.ax_time = T.root.sol000.phase000.time[:] - self.ax_freq = T.root.sol000.phase000.freq[:] - T.close() + with tables.open_file(h5_time_freq) as T: + self.ax_time = T.root.sol000.phase000.time[:] + self.ax_freq = T.root.sol000.phase000.freq[:] # use ms files for available information elif len(self.ms) > 0: @@ -393,13 +407,12 @@ def __init__(self, h5_out, h5_tables=None, ms_files=None, h5_time_freq=None, con self.ax_time = array([]) self.ax_freq = array([]) for m in self.ms: - t = ct.taql('SELECT CHAN_FREQ, CHAN_WIDTH FROM ' + os.path.abspath(m) + '::SPECTRAL_WINDOW') - self.ax_freq = append(self.ax_freq, t.getcol('CHAN_FREQ')[0]) - t.close() + with ct.taql('SELECT CHAN_FREQ, CHAN_WIDTH FROM ' + os.path.abspath(m) + '::SPECTRAL_WINDOW') as t: + self.ax_freq = append(self.ax_freq, t.getcol('CHAN_FREQ')[0]) + + with ct.table(m) as t: + self.ax_time = append(self.ax_time, t.getcol('TIME')) - t = ct.table(m) - self.ax_time = append(self.ax_time, t.getcol('TIME')) - t.close() self.ax_time = array(sorted(unique(self.ax_time))) self.ax_freq = array(sorted(unique(self.ax_freq))) @@ -410,53 +423,50 @@ def __init__(self, h5_out, h5_tables=None, ms_files=None, h5_time_freq=None, con self.ax_time = array([]) self.ax_freq = array([]) for h5_name in self.h5_tables: - h5 = tables.open_file(h5_name) - for solset in h5.root._v_groups.keys(): - ss = h5.root._f_get_child(solset) - for soltab in ss._v_groups.keys(): - st = ss._f_get_child(soltab) - axes = make_utf8(st.val.attrs['AXES']).split(',') - if 'time' in axes: - time = st._f_get_child('time')[:] - self.ax_time = sort(unique(append(self.ax_time, time))) - else: - print('No time axes in ' + h5_name + '/' + solset + '/' + soltab) - if 'freq' in axes: - freq = st._f_get_child('freq')[:] - self.ax_freq = sort(unique(append(self.ax_freq, freq))) - else: - print('No freq axes in ' + h5_name + '/' + solset + '/' + soltab) - h5.close() + with tables.open_file(h5_name) as h5: + for solset in h5.root._v_groups.keys(): + ss = h5.root._f_get_child(solset) + for soltab in ss._v_groups.keys(): + st = ss._f_get_child(soltab) + axes = make_utf8(st.val.attrs['AXES']).split(',') + if 'time' in axes: + time = st._f_get_child('time')[:] + self.ax_time = sort(unique(append(self.ax_time, time))) + else: + print('No time axes in ' + h5_name + '/' + solset + '/' + soltab) + if 'freq' in axes: + freq = st._f_get_child('freq')[:] + self.ax_freq = sort(unique(append(self.ax_freq, freq))) + else: + print('No freq axes in ' + h5_name + '/' + solset + '/' + soltab) # get polarization output axis and check number of error and tec tables in merge list self.polarizations, polarizations = [], [] self.doublefulljones = False self.tecnum, self.errornum = 0, 0 # to average in case of multiple tables for n, h5_name in enumerate(self.h5_tables): - h5 = tables.open_file(h5_name) - if 'phase000' in h5.root.sol000._v_children.keys() \ - and 'pol' in make_utf8(h5.root.sol000.phase000.val.attrs["AXES"]).split(','): - polarizations = h5.root.sol000.phase000.pol[:] - - # having two fulljones solution files to merge --> we will use a matrix multiplication for this type of merge - if len(polarizations) == len(self.polarizations) == 4: - self.doublefulljones = True - self.fulljones_phases = OrderedDict() - self.fulljones_amplitudes = OrderedDict() - - # take largest polarization list/array - if len(polarizations) > len(self.polarizations): - if type(self.polarizations) == list: - self.polarizations = polarizations - else: - self.polarizations = polarizations.copy() + with tables.open_file(h5_name) as h5: + if 'phase000' in h5.root.sol000._v_children.keys() \ + and 'pol' in make_utf8(h5.root.sol000.phase000.val.attrs["AXES"]).split(','): + polarizations = h5.root.sol000.phase000.pol[:] + + # having two fulljones solution files to merge --> we will use a matrix multiplication for this type of merge + if len(polarizations) == len(self.polarizations) == 4: + self.doublefulljones = True + self.fulljones_phases = OrderedDict() + self.fulljones_amplitudes = OrderedDict() + + # take largest polarization list/array + if len(polarizations) > len(self.polarizations): + if type(self.polarizations) == list: + self.polarizations = polarizations + else: + self.polarizations = polarizations.copy() - if 'tec000' in h5.root.sol000._v_children.keys(): - self.tecnum += 1 - if 'error000' in h5.root.sol000._v_children.keys(): - self.errornum += 1 - - h5.close() + if 'tec000' in h5.root.sol000._v_children.keys(): + self.tecnum += 1 + if 'error000' in h5.root.sol000._v_children.keys(): + self.errornum += 1 # check if fulljones if len(self.polarizations) == 4: @@ -477,86 +487,69 @@ def __init__(self, h5_out, h5_tables=None, ms_files=None, h5_time_freq=None, con if not self.have_same_antennas and not no_antenna_crash: sys.exit('ERROR: Antenna tables are not the same') - # convert tec to phase? - self.convert_tec = convert_tec - self.has_converted_tec = False - self.merge_all_in_one = merge_all_in_one - if filtered_dir: - self.filtered_dir = filtered_dir - else: - self.filtered_dir = None - - # possible solution axis in order used for our merging script - self.solaxnames = ['pol', 'dir', 'ant', 'freq', 'time'] - # directions in an ordered dictionary self.directions = OrderedDict() if len(self.directions) > 1 and self.doublefulljones: sys.exit("ERROR: Merging not compatitable with multiple directions and double fuljones merge") # TODO: update - self.freq_concat = freq_concat - self.time_concat = time_concat + @property def have_same_antennas(self): """ - Compare antenna tables with each other. - These should be the same. - - :return: boolean if antennas are the same (True/False). - """ - - for h5_name1 in self.h5_tables: - H_ref = tables.open_file(h5_name1) - for solset1 in H_ref.root._v_groups.keys(): - ss1 = H_ref.root._f_get_child(solset1) - antennas_ref = ss1.antenna[:] - for soltab1 in ss1._v_groups.keys(): - if (len(antennas_ref['name']) != len(ss1._f_get_child(soltab1).ant[:])) or \ - (not all(antennas_ref['name'] == ss1._f_get_child(soltab1).ant[:])): - message = '\n'.join(['\nMismatch in antenna tables in ' + h5_name1, - 'Antennas from ' + '/'.join([solset1, 'antenna']), - str(antennas_ref['name']), - 'Antennas from ' + '/'.join([solset1, soltab1, 'ant']), - ss1._f_get_child(soltab1).ant[:]]) - print(message) - H_ref.close() - return False - for soltab2 in ss1._v_groups.keys(): - if (len(ss1._f_get_child(soltab1).ant[:]) != - len(ss1._f_get_child(soltab2).ant[:])) or \ - (not all(ss1._f_get_child(soltab1).ant[:] == - ss1._f_get_child(soltab2).ant[:])): - message = '\n'.join(['\nMismatch in antenna tables in ' + h5_name1, - 'Antennas from ' + '/'.join([solset1, soltab1, 'ant']), - ss1._f_get_child(soltab1).ant[:], - 'Antennas from ' + '/'.join([solset1, soltab2, 'ant']), - ss1._f_get_child(soltab2).ant[:]]) - print(message) - H_ref.close() - return False - for h5_name2 in self.h5_tables: - H = tables.open_file(h5_name2) - for solset2 in H.root._v_groups.keys(): - ss2 = H.root._f_get_child(solset2) - antennas = ss2.antenna[:] - if (len(antennas_ref['name']) != len(antennas['name'])) \ - or (not all(antennas_ref['name'] == antennas['name'])): - message = '\n'.join( - ['\nMismatch between antenna tables from ' + h5_name1 + ' and ' + h5_name2, - 'Antennas from ' + h5_name1 + ' and ', - str(antennas_ref['name']), - 'Antennas from ' + h5_name2 + ':', - str(antennas['name'])]) - print(message) - H.close() - H_ref.close() + Compare antenna tables across all H5 files. + All antenna tables should be the same. + + :return: Boolean indicating whether antennas are the same (True/False). + """ + + for i, h5_name1 in enumerate(self.h5_tables): + with tables.open_file(h5_name1, mode='r') as H_ref: + for solset1 in H_ref.root._v_groups.values(): + ss1 = solset1 + antennas_ref = ss1.antenna[:] + + # Check antennas within the same file + for soltab1 in ss1._v_groups.values(): + antennas1 = soltab1.ant[:] + if len(antennas_ref['name']) != len(antennas1) or not array_equal(antennas_ref['name'], antennas1): + self._print_antenna_mismatch(h5_name1, solset1._v_name, antennas_ref['name'], soltab1._v_name, antennas1) return False - H.close() - H_ref.close() + # Check antennas across other H5 files + for h5_name2 in self.h5_tables[i+1:]: + with tables.open_file(h5_name2, mode='r') as H: + for solset2 in H.root._v_groups.values(): + antennas2 = solset2.antenna[:] + if len(antennas_ref['name']) != len(antennas2['name']) or not array_equal(antennas_ref['name'], antennas2['name']): + self._print_antenna_mismatch_between_files(h5_name1, h5_name2, antennas_ref['name'], antennas2['name']) + return False return True + @staticmethod + def _print_antenna_mismatch(h5_name, solset_name, antennas_ref, soltab_name, antennas): + """ + Print a message for antenna mismatch within the same H5 file. + """ + message = '\n'.join([ + f'\nMismatch in antenna tables in {h5_name}', + f'Antennas from {solset_name}/antenna: {antennas_ref}', + f'Antennas from {solset_name}/{soltab_name}/ant: {antennas}' + ]) + print(message) + + @staticmethod + def _print_antenna_mismatch_between_files(self, h5_name1, h5_name2, antennas_ref, antennas): + """ + Print a message for antenna mismatch between two H5 files. + """ + message = '\n'.join([ + f'\nMismatch between antenna tables from {h5_name1} and {h5_name2}', + f'Antennas from {h5_name1}: {antennas_ref}', + f'Antennas from {h5_name2}: {antennas}' + ]) + print(message) + def concat(self, values, soltab, axes, ax_name): """ @@ -706,7 +699,6 @@ def _unpack_h5(self, st, solset, soltab): values = self._expand_poldim(values, len(self.polarizations), remove_numbers(soltab), False) self.axes_final.insert(0, 'pol') - # time interpolation if self.time_concat: values = self.concat(values, st.getType(), time_axes, 'time') @@ -728,10 +720,6 @@ def _sort_soltabs(self, soltabs): Sort solution tables. This is important to run the steps and add directions according to our algorithm. - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - Dont touch this part if you dont have to. ;-) - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - :param soltabs: solutions tables :return: sorted phase, tec, amplitude, rotation, error tables @@ -763,11 +751,6 @@ def get_allkeys(self): """ Get all solution sets, solutions tables, and ax names in lists. This returns an order that is optimized for this code. - - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - Dont touch this part if you dont have to. ;-) - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - """ self.all_soltabs, self.all_solsets, self.all_axes, self.ant = [], [], [], [] @@ -1304,15 +1287,8 @@ def reshape_dir(arr): Ayx = zeros(Axx.shape).astype(complex128) Ayy = take(output_table, indices=[3], axis=pol_ax) - # if sys.version_info.major > 2: - # print('0%', end='...') - # else: - # print('Start merging ...') - # Looping over tables to construct output for n, input_table in enumerate(tables): - # if sys.version_info.major > 2: - # print(str(int((n+1)*100/len(tables)))+'%', end='...') Bxx = reshape_dir(take(input_table, indices=[0], axis=pol_ax)) Bxy = reshape_dir(take(input_table, indices=[1], axis=pol_ax)) @@ -1526,7 +1502,6 @@ def create_new_dataset(self, solset, soltab): return self - def change_pol(self, single=False, nopol=False): """ Change polarization dimension @@ -1644,88 +1619,84 @@ def add_ms_antennas(self, keep_h5_interstations=None): if len(self.ms) == 0: sys.exit("ERROR: Measurement set needed to add antennas. Use --ms.") - t = ct.table(self.ms[0] + "::ANTENNA", ack=False) - try: - ms_antlist = [n.decode('utf8') for n in t.getcol('NAME')] - except AttributeError: - ms_antlist = t.getcol('NAME') - ms_antpos = t.getcol('POSITION') - ms_antennas = array([list(zip(*(ms_antlist, ms_antpos)))], dtype=[('name', 'S16'), ('position', ' ' + str(st.val.shape)) + valtype = str(st._f_get_child(axes).dtype) + if '16' in valtype: + atomtype = tables.Float16Atom() + elif '32' in valtype: + atomtype = tables.Float32Atom() + elif '64' in valtype: + atomtype = tables.Float64Atom() + else: + atomtype = tables.Float64Atom() - H.close() + st._f_get_child(axes)._f_remove() + H.create_array(st, axes, ms_values.astype(valtype), atom=atomtype) + st._f_get_child(axes).attrs['AXES'] = attrsaxes + print('Value shape after --> ' + str(st.val.shape)) return self @@ -1734,20 +1705,19 @@ def add_template(self): Make template for phase000 and/or amplitude000 if missing. """ - H = tables.open_file(self.h5name_out, 'r+') - soltabs = list(H.root.sol000._v_groups.keys()) - if 'amplitude000' not in soltabs: - if 'phase000' in soltabs: - self.amplitudes = ones(H.root.sol000.phase000.val.shape) - if len(self.polarizations) == 4: - self.amplitudes[..., 1] = 0. - self.amplitudes[..., 2] = 0. - self.create_new_dataset('sol000', 'amplitude') - if 'phase000' not in soltabs: - if 'amplitude000' in soltabs: - self.phases = zeros(H.root.sol000.amplitude000.val.shape) - self.create_new_dataset('sol000', 'phase') - H.close() + with tables.open_file(self.h5name_out, 'r+') as H: + soltabs = list(H.root.sol000._v_groups.keys()) + if 'amplitude000' not in soltabs: + if 'phase000' in soltabs: + self.amplitudes = ones(H.root.sol000.phase000.val.shape) + if len(self.polarizations) == 4: + self.amplitudes[..., 1] = 0. + self.amplitudes[..., 2] = 0. + self.create_new_dataset('sol000', 'amplitude') + if 'phase000' not in soltabs: + if 'amplitude000' in soltabs: + self.phases = zeros(H.root.sol000.amplitude000.val.shape) + self.create_new_dataset('sol000', 'phase') return self @@ -1760,121 +1730,118 @@ def add_weights(self): print("\nPropagating weights in:") - H = tables.open_file(self.h5name_out, 'r+') - for solset in H.root._v_groups.keys(): - ss = H.root._f_get_child(solset) - soltabs = list(ss._v_groups.keys()) + with tables.open_file(self.h5name_out, 'r+') as H: + for solset in H.root._v_groups.keys(): + ss = H.root._f_get_child(solset) + soltabs = list(ss._v_groups.keys()) - if self.has_converted_tec and not any(['tec' in i for i in soltabs]): - soltabs += ['tec000'] + if self.has_converted_tec and not any(['tec' in i for i in soltabs]): + soltabs += ['tec000'] - print(soltabs) - for n, soltab in enumerate(soltabs): - print(soltab + ', from:') + print(soltabs) + for n, soltab in enumerate(soltabs): + print(soltab + ', from:') - if 'tec' in soltab and \ - soltab not in list(ss._v_groups.keys()): - st = ss._f_get_child(soltab.replace('tec', 'phase')) - else: - st = ss._f_get_child(soltab) + if 'tec' in soltab and \ + soltab not in list(ss._v_groups.keys()): + st = ss._f_get_child(soltab.replace('tec', 'phase')) + else: + st = ss._f_get_child(soltab) - shape = st.val.shape + shape = st.val.shape + + if self.has_converted_tec and 'tec' in soltab: + weight_out = ss._f_get_child(soltab.replace('tec', 'phase')).weight[:] + else: + weight_out = ones(shape) + + axes_new = make_utf8(st.val.attrs["AXES"]).split(',') + for m, input_h5 in enumerate(self.h5_tables): + print(input_h5) + with tables.open_file(input_h5) as T: + if soltab not in list(T.root._f_get_child(solset)._v_groups.keys()): + print(soltab+' not in '+input_h5) + continue + st2 = T.root._f_get_child(solset)._f_get_child(soltab) + axes = make_utf8(st2.val.attrs["AXES"]).split(',') + weight = st2.weight[:] + weight = reorderAxes(weight, axes, [a for a in axes_new if a in axes]) + + # important to do the following in case the input tables are not all different directions + m = min(weight_out.shape[axes_new.index('dir')] - 1, m) + if self.merge_all_in_one: + m = 0 + + newvals = self._interp_along_axis(weight, st2.time[:], st.time[:], axes_new.index('time'), fill_value=1.).astype(float32) + newvals = self._interp_along_axis(newvals, st2.freq[:], st.freq[:], axes_new.index('freq'), fill_value=1.).astype(float32) + + if weight.shape[-2] != 1 and len(weight.shape) == 5: + print("Merge multi-dir weights") + if weight.shape[-2] != weight_out.shape[-2]: + sys.exit("ERROR: multi-dirs do not have equal shape.") + if 'pol' not in axes and 'pol' in axes_new: + newvals = expand_dims(newvals, axis=axes_new.index('pol')) + for n in range(weight_out.shape[-1]): + weight_out[..., n] *= newvals[..., -1] + + elif weight.ndim != weight_out.ndim: # not the same value shape + if 'pol' not in axes and 'pol' in axes_new: + newvals = expand_dims(newvals, axis=axes_new.index('pol')) + if newvals.shape[-1] != weight_out.shape[-1]: + temp = ones(shape) + for n in range(temp.shape[-1]): + temp[..., m, n] *= newvals[..., 0, 0] + newvals = temp + weight_out *= newvals + else: + sys.exit('ERROR: Upsampling of weights bug due to unexpected missing axes.\n axes from ' + + input_h5 + ': ' + str(axes) + '\n axes from ' + + self.h5name_out + ': ' + str(axes_new) + '.\n' + + debug_message) + elif set(axes) == set(axes_new) and 'pol' in axes: # same axes + pol_index = axes_new.index('pol') + if weight_out.shape[pol_index] == newvals.shape[pol_index]: # same pol numbers + weight_out[:, :, :, m, ...] *= newvals[:, :, :, 0, ...] + else: # not the same polarization axis + if newvals.shape[pol_index] != newvals.shape[-1]: + sys.exit('ERROR: Upsampling of weights bug due to polarization axis mismatch.\n' + + debug_message) + if newvals.shape[pol_index] == 1: # new values have only 1 pol axis + for i in range(weight_out.shape[pol_index]): + weight_out[:, :, :, m, i] *= newvals[:, :, :, 0, 0] + elif newvals.shape[pol_index] == 2 and weight_out.shape[pol_index] == 1: + for i in range(newvals.shape[pol_index]): + weight_out[:, :, :, m, 0] *= newvals[:, :, :, 0, i] + elif newvals.shape[pol_index] == 2 and weight_out.shape[pol_index] == 4: + weight_out[:, :, :, m, 0] *= newvals[:, :, :, 0, 0] + weight_out[:, :, :, m, 1] *= newvals[:, :, :, 0, 0] * newvals[:, :, :, 0, -1] + weight_out[:, :, :, m, 2] *= newvals[:, :, :, 0, 0] * newvals[:, :, :, 0, -1] + weight_out[:, :, :, m, -1] *= newvals[:, :, :, 0, -1] + else: + sys.exit('ERROR: Upsampling of weights bug due to unexpected polarization mismatch.\n' + + debug_message) + elif set(axes) == set(axes_new) and 'pol' not in axes: # same axes but no pol + dirind = axes_new.index('dir') + if weight_out.shape[dirind] != newvals.shape[dirind] and newvals.shape[dirind] == 1: + if len(weight_out.shape) == 4: + weight_out[:, :, m, :] *= newvals[:, :, 0, :] + elif len(weight_out.shape) == 5: + weight_out[:, :, :, m, :] *= newvals[:, :, :, 0, :] + elif weight_out.shape[dirind] != newvals.shape[dirind]: + sys.exit( + 'ERROR: Upsampling of weights because same direction exists multiple times in input h5 ' + '(verify and update directions or remove --propagate_flags)') + else: + weight_out *= newvals - if self.has_converted_tec and 'tec' in soltab: - weight_out = ss._f_get_child(soltab.replace('tec', 'phase')).weight[:] - else: - weight_out = ones(shape) - - axes_new = make_utf8(st.val.attrs["AXES"]).split(',') - for m, input_h5 in enumerate(self.h5_tables): - print(input_h5) - T = tables.open_file(input_h5) - if soltab not in list(T.root._f_get_child(solset)._v_groups.keys()): - print(soltab+' not in '+input_h5) - T.close() - continue - st2 = T.root._f_get_child(solset)._f_get_child(soltab) - axes = make_utf8(st2.val.attrs["AXES"]).split(',') - weight = st2.weight[:] - weight = reorderAxes(weight, axes, [a for a in axes_new if a in axes]) - - # important to do the following in case the input tables are not all different directions - m = min(weight_out.shape[axes_new.index('dir')] - 1, m) - if self.merge_all_in_one: - m = 0 - - newvals = self._interp_along_axis(weight, st2.time[:], st.time[:], axes_new.index('time'), fill_value=1.).astype(float32) - newvals = self._interp_along_axis(newvals, st2.freq[:], st.freq[:], axes_new.index('freq'), fill_value=1.).astype(float32) - - if weight.shape[-2] != 1 and len(weight.shape) == 5: - print("Merge multi-dir weights") - if weight.shape[-2] != weight_out.shape[-2]: - sys.exit("ERROR: multi-dirs do not have equal shape.") - if 'pol' not in axes and 'pol' in axes_new: - newvals = expand_dims(newvals, axis=axes_new.index('pol')) - for n in range(weight_out.shape[-1]): - weight_out[..., n] *= newvals[..., -1] - - elif weight.ndim != weight_out.ndim: # not the same value shape - if 'pol' not in axes and 'pol' in axes_new: - newvals = expand_dims(newvals, axis=axes_new.index('pol')) - if newvals.shape[-1] != weight_out.shape[-1]: - temp = ones(shape) - for n in range(temp.shape[-1]): - temp[..., m, n] *= newvals[..., 0, 0] - newvals = temp - weight_out *= newvals - else: - sys.exit('ERROR: Upsampling of weights bug due to unexpected missing axes.\n axes from ' - + input_h5 + ': ' + str(axes) + '\n axes from ' - + self.h5name_out + ': ' + str(axes_new) + '.\n' - + debug_message) - elif set(axes) == set(axes_new) and 'pol' in axes: # same axes - pol_index = axes_new.index('pol') - if weight_out.shape[pol_index] == newvals.shape[pol_index]: # same pol numbers - weight_out[:, :, :, m, ...] *= newvals[:, :, :, 0, ...] - else: # not the same polarization axis - if newvals.shape[pol_index] != newvals.shape[-1]: - sys.exit('ERROR: Upsampling of weights bug due to polarization axis mismatch.\n' - + debug_message) - if newvals.shape[pol_index] == 1: # new values have only 1 pol axis - for i in range(weight_out.shape[pol_index]): - weight_out[:, :, :, m, i] *= newvals[:, :, :, 0, 0] - elif newvals.shape[pol_index] == 2 and weight_out.shape[pol_index] == 1: - for i in range(newvals.shape[pol_index]): - weight_out[:, :, :, m, 0] *= newvals[:, :, :, 0, i] - elif newvals.shape[pol_index] == 2 and weight_out.shape[pol_index] == 4: - weight_out[:, :, :, m, 0] *= newvals[:, :, :, 0, 0] - weight_out[:, :, :, m, 1] *= newvals[:, :, :, 0, 0] * newvals[:, :, :, 0, -1] - weight_out[:, :, :, m, 2] *= newvals[:, :, :, 0, 0] * newvals[:, :, :, 0, -1] - weight_out[:, :, :, m, -1] *= newvals[:, :, :, 0, -1] else: - sys.exit('ERROR: Upsampling of weights bug due to unexpected polarization mismatch.\n' + sys.exit('ERROR: Upsampling of weights bug due to unexpected missing axes.\n axes from ' + + input_h5 + ': ' + str(axes) + '\n axes from ' + + self.h5name_out + ': ' + str(axes_new) + '.\n' + debug_message) - elif set(axes) == set(axes_new) and 'pol' not in axes: # same axes but no pol - dirind = axes_new.index('dir') - if weight_out.shape[dirind] != newvals.shape[dirind] and newvals.shape[dirind] == 1: - if len(weight_out.shape) == 4: - weight_out[:, :, m, :] *= newvals[:, :, 0, :] - elif len(weight_out.shape) == 5: - weight_out[:, :, :, m, :] *= newvals[:, :, :, 0, :] - elif weight_out.shape[dirind] != newvals.shape[dirind]: - sys.exit( - 'ERROR: Upsampling of weights because same direction exists multiple times in input h5 ' - '(verify and update directions or remove --propagate_flags)') - else: - weight_out *= newvals - - else: - sys.exit('ERROR: Upsampling of weights bug due to unexpected missing axes.\n axes from ' - + input_h5 + ': ' + str(axes) + '\n axes from ' - + self.h5name_out + ': ' + str(axes_new) + '.\n' - + debug_message) - T.close() - st.weight[:] = weight_out + st.weight[:] = weight_out - H.close() print('\n') return self @@ -1884,52 +1851,51 @@ def format_tables(self): Format direction tables (making sure dir and sources are the same). """ - H = tables.open_file(self.h5name_out, 'r+') - for solset in H.root._v_groups.keys(): - ss = H.root._f_get_child(solset) - sources = ss.source[:]['name'] - for soltab in ss._v_groups.keys(): - st = ss._f_get_child(soltab) - dirs = st._f_get_child('dir')[:] - if len(sources[:]) > len(dirs): - difference = list(set(sources) - set(dirs)) - - newdir = list(st._f_get_child('dir')[:]) + difference - st._f_get_child('dir')._f_remove() - H.create_array(st, 'dir', array(newdir).astype('|S5')) - - dir_ind = st.val.attrs['AXES'].decode('utf8').split(',').index('dir') - - for axes in ['val', 'weight']: - axs = st._f_get_child(axes).attrs['AXES'] - newval = st._f_get_child(axes)[:] - - shape = list(newval.shape) - for _ in difference: - shape[dir_ind] = 1 - if 'amplitude' in soltab or axes == 'weight': - newval = append(newval, ones(shape), - axis=dir_ind) + with tables.open_file(self.h5name_out, 'r+') as H: + for solset in H.root._v_groups.keys(): + ss = H.root._f_get_child(solset) + sources = ss.source[:]['name'] + for soltab in ss._v_groups.keys(): + st = ss._f_get_child(soltab) + dirs = st._f_get_child('dir')[:] + if len(sources[:]) > len(dirs): + difference = list(set(sources) - set(dirs)) + + newdir = list(st._f_get_child('dir')[:]) + difference + st._f_get_child('dir')._f_remove() + H.create_array(st, 'dir', array(newdir).astype('|S5')) + + dir_ind = st.val.attrs['AXES'].decode('utf8').split(',').index('dir') + + for axes in ['val', 'weight']: + axs = st._f_get_child(axes).attrs['AXES'] + newval = st._f_get_child(axes)[:] + + shape = list(newval.shape) + for _ in difference: + shape[dir_ind] = 1 + if 'amplitude' in soltab or axes == 'weight': + newval = append(newval, ones(shape), + axis=dir_ind) + else: + newval = append(newval, zeros(shape), + axis=dir_ind) + + valtype = str(st._f_get_child(axes).dtype) + if '16' in valtype: + atomtype = tables.Float16Atom() + elif '32' in valtype: + atomtype = tables.Float32Atom() + elif '64' in valtype: + atomtype = tables.Float64Atom() else: - newval = append(newval, zeros(shape), - axis=dir_ind) - - valtype = str(st._f_get_child(axes).dtype) - if '16' in valtype: - atomtype = tables.Float16Atom() - elif '32' in valtype: - atomtype = tables.Float32Atom() - elif '64' in valtype: - atomtype = tables.Float64Atom() - else: - atomtype = tables.Float64Atom() + atomtype = tables.Float64Atom() - st._f_get_child(axes)._f_remove() - H.create_array(st, axes, newval.astype(valtype), atom=atomtype) - st._f_get_child(axes).attrs['AXES'] = axs - elif len(sources[:]) < len(dirs): - output_check(self.h5name_out) - H.close() + st._f_get_child(axes)._f_remove() + H.create_array(st, axes, newval.astype(valtype), atom=atomtype) + st._f_get_child(axes).attrs['AXES'] = axs + elif len(sources[:]) < len(dirs): + output_check(self.h5name_out) return self @@ -1962,13 +1928,13 @@ def _change_solset(h5, solset_in, solset_out, delete=True, overwrite=True): 2) Delete solset_in if delete==True """ - H = tables.open_file(h5, 'r+') - H.root._f_get_child(solset_in)._f_copy(H.root, newname=solset_out, overwrite=overwrite, recursive=True) - print('Succesfully copied ' + solset_in + ' to ' + solset_out) - if delete: - H.root._f_get_child(solset_in)._f_remove(recursive=True) - print('Removed ' + solset_in + ' in output') - H.close() + with tables.open_file(h5, 'r+') as H: + H.root._f_get_child(solset_in)._f_copy(H.root, newname=solset_out, overwrite=overwrite, recursive=True) + print('Succesfully copied ' + solset_in + ' to ' + solset_out) + if delete: + H.root._f_get_child(solset_in)._f_remove(recursive=True) + print('Removed ' + solset_in + ' in output') + return @@ -1984,74 +1950,72 @@ def output_check(h5): print('\nChecking output...') - H = tables.open_file(h5) - - # check number of solset - assert len(list(H.root._v_groups.keys())) == 1, \ - 'More than 1 solset in ' + str(list(H.root._v_groups.keys())) + '. Only 1 is allowed for h5_merger.py.' - - for solset in H.root._v_groups.keys(): - - # check sol00.. name - assert 'sol' in solset, solset + ' is a wrong solset name, should be sol***' - ss = H.root._f_get_child(solset) - - # check antennas - antennas = ss.antenna - assert antennas.attrs.FIELD_0_NAME == 'name', 'No name in ' + '/'.join([solset, 'antenna']) - assert antennas.attrs.FIELD_1_NAME == 'position', 'No coordinate in ' + '/'.join([solset, 'antenna']) - - # check sources - sources = ss.source - assert sources.attrs.FIELD_0_NAME == 'name', 'No name in ' + '/'.join([solset, 'source']) - assert sources.attrs.FIELD_1_NAME == 'dir', 'No coordinate in ' + '/'.join([solset, 'source']) - - for soltab in ss._v_groups.keys(): - st = ss._f_get_child(soltab) - assert st.val.shape == st.weight.shape, \ - 'weight ' + str(st.weight.shape) + ' and values ' + str(st.val.shape) + ' do not have same shape' - - # check if pol and/or dir are missing - for pd in ['pol', 'dir']: - assert not (st.val.ndim == 5 and pd not in list(st._v_children.keys())), \ - '/'.join([solset, soltab, pd]) + ' is missing' - - # check if freq, time, and ant arrays are missing - for fta in ['freq', 'time', 'ant']: - assert fta in list(st._v_children.keys()), \ - '/'.join([solset, soltab, fta]) + ' is missing' - - # check if val and weight have AXES - for vw in ['val', 'weight']: - assert 'AXES' in st._f_get_child(vw).attrs._f_list("user"), \ - 'AXES missing in ' + '/'.join([solset, soltab, vw]) - - # check if dimensions of values match with length of arrays - for ax_index, ax in enumerate(st.val.attrs['AXES'].decode('utf8').split(',')): - assert st.val.shape[ax_index] == len(st._f_get_child(ax)[:]), \ - ax + ' length is not matching with dimension from val in ' + '/'.join([solset, soltab, ax]) - - # check if ant and antennas have equal sizes - if ax == 'ant': - assert len(antennas[:]) == len(st._f_get_child(ax)[:]), \ - '/'.join([solset, 'antenna']) + ' and ' + '/'.join( - [solset, soltab, ax]) + ' do not have same length' - - # check if dir and sources have equal sizes - if ax == 'dir': - assert len(sources[:]) == len(st._f_get_child(ax)[:]), \ - '/'.join([solset, 'source']) + ' and ' + '/'.join( - [solset, soltab, ax]) + ' do not have same length' - - # check if phase and amplitude have same shapes - for soltab1 in ss._v_groups.keys(): - if ('phase' in soltab or 'amplitude' in soltab) and ('phase' in soltab1 or 'amplitude' in soltab1): - st1 = ss._f_get_child(soltab1) - assert st.val.shape == st1.val.shape, \ - '/'.join([solset, soltab, 'val']) + ' shape: ' + str(st.weight.shape) + \ - '/'.join([solset, soltab1, 'val']) + ' shape: ' + str(st1.weight.shape) - - H.close() + with tables.open_file(h5) as H: + + # check number of solset + assert len(list(H.root._v_groups.keys())) == 1, \ + 'More than 1 solset in ' + str(list(H.root._v_groups.keys())) + '. Only 1 is allowed for h5_merger.py.' + + for solset in H.root._v_groups.keys(): + + # check sol00.. name + assert 'sol' in solset, solset + ' is a wrong solset name, should be sol***' + ss = H.root._f_get_child(solset) + + # check antennas + antennas = ss.antenna + assert antennas.attrs.FIELD_0_NAME == 'name', 'No name in ' + '/'.join([solset, 'antenna']) + assert antennas.attrs.FIELD_1_NAME == 'position', 'No coordinate in ' + '/'.join([solset, 'antenna']) + + # check sources + sources = ss.source + assert sources.attrs.FIELD_0_NAME == 'name', 'No name in ' + '/'.join([solset, 'source']) + assert sources.attrs.FIELD_1_NAME == 'dir', 'No coordinate in ' + '/'.join([solset, 'source']) + + for soltab in ss._v_groups.keys(): + st = ss._f_get_child(soltab) + assert st.val.shape == st.weight.shape, \ + 'weight ' + str(st.weight.shape) + ' and values ' + str(st.val.shape) + ' do not have same shape' + + # check if pol and/or dir are missing + for pd in ['pol', 'dir']: + assert not (st.val.ndim == 5 and pd not in list(st._v_children.keys())), \ + '/'.join([solset, soltab, pd]) + ' is missing' + + # check if freq, time, and ant arrays are missing + for fta in ['freq', 'time', 'ant']: + assert fta in list(st._v_children.keys()), \ + '/'.join([solset, soltab, fta]) + ' is missing' + + # check if val and weight have AXES + for vw in ['val', 'weight']: + assert 'AXES' in st._f_get_child(vw).attrs._f_list("user"), \ + 'AXES missing in ' + '/'.join([solset, soltab, vw]) + + # check if dimensions of values match with length of arrays + for ax_index, ax in enumerate(st.val.attrs['AXES'].decode('utf8').split(',')): + assert st.val.shape[ax_index] == len(st._f_get_child(ax)[:]), \ + ax + ' length is not matching with dimension from val in ' + '/'.join([solset, soltab, ax]) + + # check if ant and antennas have equal sizes + if ax == 'ant': + assert len(antennas[:]) == len(st._f_get_child(ax)[:]), \ + '/'.join([solset, 'antenna']) + ' and ' + '/'.join( + [solset, soltab, ax]) + ' do not have same length' + + # check if dir and sources have equal sizes + if ax == 'dir': + assert len(sources[:]) == len(st._f_get_child(ax)[:]), \ + '/'.join([solset, 'source']) + ' and ' + '/'.join( + [solset, soltab, ax]) + ' do not have same length' + + # check if phase and amplitude have same shapes + for soltab1 in ss._v_groups.keys(): + if ('phase' in soltab or 'amplitude' in soltab) and ('phase' in soltab1 or 'amplitude' in soltab1): + st1 = ss._f_get_child(soltab1) + assert st.val.shape == st1.val.shape, \ + '/'.join([solset, soltab, 'val']) + ' shape: ' + str(st.weight.shape) + \ + '/'.join([solset, soltab1, 'val']) + ' shape: ' + str(st1.weight.shape) print('Awesome! Output has all necessary information and correct dimensions.') @@ -2359,16 +2323,12 @@ def add_antenna_source_tables(self): Add antenna and source table to output file """ - T = tables.open_file(self.h5in_name) - H = tables.open_file(self.h5out_name, 'r+') - - for solset in T.root._v_groups.keys(): - ss = T.root._f_get_child(solset) - overwrite_table(H, solset, 'antenna', ss.antenna[:]) - overwrite_table(H, solset, 'source', ss.source[:]) - - T.close() - H.close() + with tables.open_file(self.h5in_name) as T: + with tables.open_file(self.h5out_name, 'r+') as H: + for solset in T.root._v_groups.keys(): + ss = T.root._f_get_child(solset) + overwrite_table(H, solset, 'antenna', ss.antenna[:]) + overwrite_table(H, solset, 'source', ss.source[:]) return @@ -2492,81 +2452,78 @@ def h5_check(h5): print( '%%%%%%%%%%%%%%%%%%%%%%%%%\nSOLUTION FILE CHECK START\n%%%%%%%%%%%%%%%%%%%%%%%%%\n\nSolution file name:\n' + h5) - H = tables.open_file(h5) - solsets = list(H.root._v_groups.keys()) - print('\nFollowing solution sets in ' + h5 + ':\n' + '\n'.join(solsets)) - for solset in solsets: - ss = H.root._f_get_child(solset) - soltabs = list(ss._v_groups.keys()) - print('\nFollowing solution tables in ' + solset + ':\n' + '\n'.join(soltabs)) - print('\nFollowing stations in ' + solset + ':\n' + ', '.join( - [make_utf8(a) for a in list(ss.antenna[:]['name'])])) - print('\nFollowing sources in ' + solset + ':\n' + '\n'.join( - [make_utf8(a['name']) + '-->' + str([a['dir'][0], a['dir'][1]]) for a in list(ss.source[:])])) - for soltab in soltabs: - st = ss._f_get_child(soltab) - for table in ['val', 'weight']: - axes = make_utf8(st._f_get_child(table).attrs["AXES"]) - print('\n' + '/'.join([solset, soltab, table]) + ' axes:\n' + - axes) - print('/'.join([solset, soltab, table]) + ' shape:\n' + - str(st._f_get_child(table).shape)) - if table == 'weight': - weights = st._f_get_child(table)[:] - element_sum = 1 - for w in weights.shape: element_sum *= w - flagged = round(sum(weights == 0.) / element_sum * 100, 2) - print('/'.join([solset, soltab, table]) + ' flagged:\n' + str(flagged) + '%') - if 'pol' in axes: - print('/'.join([solset, soltab, 'pol']) + ':\n' + ','.join( - [make_utf8(p) for p in list(st._f_get_child('pol')[:])])) - if 'time' in axes: - time = st._f_get_child('time')[:] - # print('/'.join([solset, soltab, 'time']) + ' start:\n' + str(time[0])) - # print('/'.join([solset, soltab, 'time']) + ' end:\n' + str(time[-1])) - if len(st._f_get_child('time')[:]) > 1: - print('/'.join([solset, soltab, 'time']) + ' time resolution:\n' + str( - diff(st._f_get_child('time')[:])[0])) - if 'freq' in axes: - freq = st._f_get_child('freq')[:] - # print('/'.join([solset, soltab, 'freq']) + ' start:\n' + str(freq[0])) - # print('/'.join([solset, soltab, 'freq']) + ' end:\n' + str(freq[-1])) - if len(st._f_get_child('freq')[:]) > 1: - print('/'.join([solset, soltab, 'freq']) + ' resolution:\n' + str( - diff(st._f_get_child('freq')[:])[0])) - - H.close() + with tables.open_file(h5) as H: + solsets = list(H.root._v_groups.keys()) + print('\nFollowing solution sets in ' + h5 + ':\n' + '\n'.join(solsets)) + for solset in solsets: + ss = H.root._f_get_child(solset) + soltabs = list(ss._v_groups.keys()) + print('\nFollowing solution tables in ' + solset + ':\n' + '\n'.join(soltabs)) + print('\nFollowing stations in ' + solset + ':\n' + ', '.join( + [make_utf8(a) for a in list(ss.antenna[:]['name'])])) + print('\nFollowing sources in ' + solset + ':\n' + '\n'.join( + [make_utf8(a['name']) + '-->' + str([a['dir'][0], a['dir'][1]]) for a in list(ss.source[:])])) + for soltab in soltabs: + st = ss._f_get_child(soltab) + for table in ['val', 'weight']: + axes = make_utf8(st._f_get_child(table).attrs["AXES"]) + print('\n' + '/'.join([solset, soltab, table]) + ' axes:\n' + + axes) + print('/'.join([solset, soltab, table]) + ' shape:\n' + + str(st._f_get_child(table).shape)) + if table == 'weight': + weights = st._f_get_child(table)[:] + element_sum = 1 + for w in weights.shape: element_sum *= w + flagged = round(sum(weights == 0.) / element_sum * 100, 2) + print('/'.join([solset, soltab, table]) + ' flagged:\n' + str(flagged) + '%') + if 'pol' in axes: + print('/'.join([solset, soltab, 'pol']) + ':\n' + ','.join( + [make_utf8(p) for p in list(st._f_get_child('pol')[:])])) + if 'time' in axes: + time = st._f_get_child('time')[:] + # print('/'.join([solset, soltab, 'time']) + ' start:\n' + str(time[0])) + # print('/'.join([solset, soltab, 'time']) + ' end:\n' + str(time[-1])) + if len(st._f_get_child('time')[:]) > 1: + print('/'.join([solset, soltab, 'time']) + ' time resolution:\n' + str( + diff(st._f_get_child('time')[:])[0])) + if 'freq' in axes: + freq = st._f_get_child('freq')[:] + # print('/'.join([solset, soltab, 'freq']) + ' start:\n' + str(freq[0])) + # print('/'.join([solset, soltab, 'freq']) + ' end:\n' + str(freq[-1])) + if len(st._f_get_child('freq')[:]) > 1: + print('/'.join([solset, soltab, 'freq']) + ' resolution:\n' + str( + diff(st._f_get_child('freq')[:])[0])) print('\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%\nSOLUTION FILE CHECK FINISHED\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n') return -def _checknan_input(h5): +def _checknan_input(h5_file): """ Check h5 on nan or 0 values - :param h5: h5parm file + :param h5_file: h5parm file """ - H = tables.open_file(h5) - - try: - amp = H.root.sol000.amplitude000.val[:] - print('Amplitude:') - print(amp[(~isfinite(amp)) | (amp == 0.)]) - del amp - except: - print('H.root.amplitude000 does not exist') - - try: - phase = H.root.sol000.phase000.val[:] - print('Phase:') - print(phase[(~isfinite(phase)) | (phase == 0.)]) - del phase - except: - print('H.root.sol000.phase000 does not exist') + with tables.open_file(h5_file, mode='r') as H: + # Check for amplitude + try: + amp = H.root.sol000.amplitude000.val[:] + nan_or_zero_amp = amp[(~isfinite(amp)) | (amp == 0.)] + print('Amplitude:') + print(nan_or_zero_amp) + except AttributeError: + print('H.root.sol000.amplitude000 does not exist.') - H.close() + # Check for phase + try: + phase = H.root.sol000.phase000.val[:] + nan_or_zero_phase = phase[(~isfinite(phase)) | (phase == 0.)] + print('Phase:') + print(nan_or_zero_phase) + except AttributeError: + print('H.root.sol000.phase000 does not exist.') return @@ -2596,66 +2553,65 @@ def move_source_in_sourcetable(h5, overwrite=False, dir_idx=None, dra_degrees=0, if not overwrite: os.system('cp ' + h5 + ' ' + h5.replace('.h5', '_upd.h5')) h5 = h5.replace('.h5', '_upd.h5') - H = tables.open_file(h5, 'r+') - sources = H.root.sol000.source[:] - sources[dir_idx][1][0] += _degree_to_radian(dra_degrees) - sources[dir_idx][1][1] += _degree_to_radian(ddec_degrees) - overwrite_table(H, 'sol000', 'source', sources) - H.close() + with tables.open_file(h5, 'r+') as H: + sources = H.root.sol000.source[:] + sources[dir_idx][1][0] += _degree_to_radian(dra_degrees) + sources[dir_idx][1][1] += _degree_to_radian(ddec_degrees) + overwrite_table(H, 'sol000', 'source', sources) return -def check_freq_overlap(h5_tables): +def check_overlap(h5_tables, check_type='freq'): """ - Verify if the frequency bands between the h5 tables overlap - :param h5_tables: h5parm input tables + Verify if the frequency bands or time slots between the HDF5 tables overlap. + + :param h5_tables: List of h5parm input tables. + :param check_type: Type of overlap to check ('freq' for frequency, 'time' for time). """ - for h51 in h5_tables: - H = tables.open_file(h51) - for h52 in h5_tables: - F = tables.open_file(h52) - try: - if h51 != h52 and F.root.sol000.phase000.freq[:].max() < H.root.sol000.phase000.freq[:].min(): - print("WARNING: frequency bands between "+h51+" and "+h52+" do not overlap, you might want to use " - "--freq_concat to merge over different frequency bands") - except: - try: - if h51 != h52 and F.root.sol000.amplitude000.freq[:].max() < H.root.sol000.amplitude000.freq[ - :].min(): - print("WARNING: frequency bands between "+h51+" and "+h52+" do not overlap, you might want to use " - "--freq_concat to merge over different frequency bands") - except: - pass - F.close() - H.close() - return + def _get_data(h5_file, data_type): + """ + Helper function to extract frequency or time data from an HDF5 file. -def check_time_overlap(h5_tables): - """ - Verify if the time slots between the h5 tables overlap + :param h5_file: Opened HDF5 file object. + :param data_type: 'freq' for frequency data, 'time' for time data. + :return: Data array, or None if it doesn't exist. + """ + try: + if data_type == 'freq': + return h5_file.root.sol000.phase000.freq[:] + elif data_type == 'time': + return h5_file.root.sol000.phase000.time[:] + except tables.NoSuchNodeError: + try: + if data_type == 'freq': + return h5_file.root.sol000.amplitude000.freq[:] + elif data_type == 'time': + return h5_file.root.sol000.amplitude000.time[:] + except tables.NoSuchNodeError: + return None - :param h5_tables: h5parm input tables - """ for h51 in h5_tables: - H = tables.open_file(h51) - for h52 in h5_tables: - F = tables.open_file(h52) - try: - if h51 != h52 and F.root.sol000.phase000.time[:].max() < H.root.sol000.phase000.time[:].min(): - print( - "WARNING: time slots between "+h51+" and "+h52+" do not overlap, might result in interpolation issues") - except: - try: - if h51 != h52 and F.root.sol000.amplitude000.time[:].max() < H.root.sol000.amplitude000.time[ - :].min(): - print( - "WARNING: time slots between "+h51+" and "+h52+" do not overlap, might result in interpolation issues") - except: - pass - F.close() - H.close() + with tables.open_file(h51, mode='r') as H: + data_h51 = _get_data(H, check_type) + + for h52 in h5_tables: + if h51 == h52: + continue + + with tables.open_file(h52, mode='r') as F: + data_h52 = _get_data(F, check_type) + + if data_h51 is not None and data_h52 is not None: + if data_h52.max() < data_h51.min(): + if check_type == 'freq': + print(f"WARNING: Frequency bands between {h51} and {h52} do not overlap. " + "You might want to use --freq_concat to merge over different frequency bands.") + elif check_type == 'time': + print(f"WARNING: Time slots between {h51} and {h52} do not overlap. " + "This might result in interpolation issues.") + return @@ -2682,7 +2638,7 @@ def repack(h5): def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, convert_tec=True, merge_all_in_one=False, lin2circ=False, circ2lin=False, add_directions=None, single_pol=None, no_pol=None, use_solset='sol000', filtered_dir=None, add_cs=None, add_ms_stations=None, check_output=None, freq_av=None, time_av=None, - propagate_flags=None, freq_concat=None, time_concat=None, no_antenna_crash=None, + propagate_flags=True, freq_concat=None, time_concat=None, no_antenna_crash=None, output_summary=None, min_distance=0.): """ Main function that uses the class MergeH5 to merge h5 tables. @@ -2783,9 +2739,9 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv # Check if frequencies from h5_tables overlap if not freq_concat: - check_freq_overlap(h5_tables) + check_overlap(h5_tables, 'freq') if not time_concat: - check_time_overlap(h5_tables) + check_overlap(h5_tables, 'time') if freq_concat and time_concat: sys.exit("ERROR: Cannot do both time and frequency concat (ask Jurjen for assistance to implement this feature)") @@ -2861,10 +2817,6 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv if sys.version_info.major == 2: merge.reorder_directions() - # Check if station weights are fully flagged in input and flag in output as well - # if check_flagged_station and not propagate_flags: - # merge.flag_stations() - # Check table source size merge.reduce_memory_source() @@ -2951,13 +2903,12 @@ def parse_input(): parser.add_argument('--add_directions', default=None, help='Add direction with amplitude 1 and phase 0 (example: --add_directions [0.73,0.12]).') parser.add_argument('--single_pol', action='store_true', default=None, help='Return only a single polarization axis if both polarizations are the same.') parser.add_argument('--no_pol', action='store_true', default=None, help='Remove polarization axis if both polarizations are the same.') - # parser.add_argument('--combine_h5', action='store_true', default=None, help='Merge H5 files with different time axis into 1.') parser.add_argument('--usesolset', type=str, default='sol000', help='Choose a solset to merge from your input solution files (only necessary if not sol000 is used).') parser.add_argument('--filter_directions', type=str, default=None, help='Filter out a list of indexed directions from your output solution file. Only lists allowed (example: --filter_directions [2, 3]).') parser.add_argument('--add_cs', action='store_true', default=None, help='Add core stations to antenna output from MS (needs --ms).') parser.add_argument('--add_ms_stations', action='store_true', default=None, help='Use only antenna stations from measurement set (needs --ms). Note that this is different from --add_cs, as it does not keep the international stations if these are not in the MS.') - # parser.add_argument('--no_stationflag_check', action='store_true', default=None, help='Do not flag complete station (for all directions) if entire station is flagged somewhere in input solution file.') - parser.add_argument('--propagate_flags', action='store_true', default=None, help='Interpolate weights and return in output file.') + parser.add_argument('--propagate_flags', action='store_true', default=None, help='(NOT USED ANYMORE) Interpolate weights and return in output file.') + parser.add_argument('--no_weight_prop', action='store_true', default=None, help='No interpolation of weights.') parser.add_argument('--no_antenna_crash', action='store_true', default=None, help='Do not check if antennas are in h5.') parser.add_argument('--output_summary', action='store_true', default=None, help='Give output summary.') parser.add_argument('--check_output', action='store_true', default=None, help='Check if the output has all the correct output information.') @@ -3025,6 +2976,10 @@ def main(): if args.merge_diff_freq: print('WARNING: --merge_diff_freq given, please use --freq_concat.') + if args.propagate_flags: + print('--propagate_flags does not have a function anymore, as this is now done by default. ' + 'If you want to turn this off, you can use --no_weight_prop') + merge_h5(h5_out=args.h5_out, h5_tables=args.h5_tables, ms_files=args.ms, @@ -3043,8 +2998,7 @@ def main(): check_output=args.check_output, time_av=args.time_av, freq_av=args.freq_av, - # check_flagged_station=not args.no_stationflag_check, - propagate_flags=args.propagate_flags, + propagate_flags=not args.no_weight_prop, freq_concat=args.merge_diff_freq or args.freq_concat, time_concat=args.time_concat, no_antenna_crash=args.no_antenna_crash, diff --git a/ms_helpers/__init__.py b/ms_helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/phasediff_scores/find_solint.py b/phasediff_scores/find_solint.py deleted file mode 100644 index e7ff966c..00000000 --- a/phasediff_scores/find_solint.py +++ /dev/null @@ -1,25 +0,0 @@ -import sys -from pathlib import Path - -# Add the parent directory to sys.path -sys.path.append(str(Path(__file__).resolve().parent.parent)) - -from source_selection.phasediff_output import GetSolint - -if __name__ == "__main__": - - # set std score, for which you want to find the solint - optimal_score = 2 - - # reference solution interval in minutes - ref_solint = 10 - - # solution file - h5 = '../P23872.h5' - - # get solution interval - S = GetSolint(h5, optimal_score, ref_solint) - solint = S.best_solint - - # OPTIONAL: plot fit - S.plot_C("T=" + str(round(solint, 2)) + " min") diff --git a/phasediff_scores/phasediff_selection/phasediff.sh b/phasediff_scores/phasediff_selection/phasediff.sh deleted file mode 100644 index ea411485..00000000 --- a/phasediff_scores/phasediff_selection/phasediff.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash - -#input MS -MS=$1 - -#SCRIPT PATHS -FACETSELFCAL=$( python3 $HOME/parse_settings.py --facet_selfcal ) -LOFARFACETSELFCAL=$( python3 $HOME/parse_settings.py --lofar_facet_selfcal ) -LOFARHELPERS=$( python3 $HOME/parse_settings.py --lofar_helpers ) - -IFS='/' read -ra MSS <<< "$MS" -MSOUT=spd_${MSS[-1]} - -CURDIR=$PWD - -if [[ -f ${MS} ]] -then - echo "ERROR: ${MS} NEEDS TO HAVE ABSOLUTE PATHS" - exit 0 -fi - -#PRE-AVERAGE -DP3 \ -msin=${MS} \ -msin.orderms=False \ -msin.missingdata=True \ -msin.datacolumn=DATA \ -msout=${MSOUT} \ -msout.storagemanager=dysco \ -msout.writefullresflag=False \ -steps=[avg] \ -avg.type=averager \ -avg.freqresolution=390.56kHz \ -avg.timeresolution=60 - -#GET PHASEDIFF H5 -python $FACETSELFCAL \ --i phasediff \ ---forwidefield \ ---phaseupstations='core' \ ---skipbackup \ ---uvmin=20000 \ ---soltype-list="['scalarphasediff']" \ ---solint-list="['10min']" \ ---nchan-list="[6]" \ ---docircular \ ---uvminscalarphasediff=0 \ ---stop=1 \ ---soltypecycles-list="[0]" \ ---imsize=1600 \ ---skymodelpointsource=1.0 \ ---helperscriptspath=$LOFARFACETSELFCAL \ ---helperscriptspathh5merge=$LOFARHELPERS \ ---stopafterskysolve \ -${MSOUT} - -#BIG CLEAN UP -#mv *phasediff*.h5 ${CURDIR}/h5output -#rm -rf *.fits -#rm -rf *.p -#rm -rf tmpfile -##rm *.h5 -#rm -rf *.avg -#rm -rf *.phaseup -#rm -rf *.parset -#rm BLsmooth.py -#rm lin2circ.py -#rm lib_multiproc.py -#rm h5_merger.py -#rm polconv.py -#rm plot_tecandphase.py -#rm VLASS_dyn_summary.php -#rm vlass_search.py -#rm -rf __pycache__ -#rm *.log -#rm merged*.h5 -#rm *.scbackup -#rm *templatejones.h5 -#rm *.png - -cd ../ \ No newline at end of file diff --git a/phasediff_scores/phasediff_selection/phasediff_multi.sh b/phasediff_scores/phasediff_selection/phasediff_multi.sh deleted file mode 100644 index f4ede743..00000000 --- a/phasediff_scores/phasediff_selection/phasediff_multi.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -#SCRIPT PATHS -FACETSELFCAL=$( python3 $HOME/parse_settings.py --facet_selfcal ) -LOFARFACETSELFCAL=$( python3 $HOME/parse_settings.py --lofar_facet_selfcal ) -LOFARHELPERS=$( python3 $HOME/parse_settings.py --lofar_helpers ) - -#TODO: NO PRE-AVERAGE - -#GET PHASEDIFF H5 -python $FACETSELFCAL \ --i phasediff \ ---forwidefield \ ---phaseupstations='core' \ ---skipbackup \ ---uvmin=20000 \ ---soltype-list="['scalarphasediff']" \ ---solint-list="['10min']" \ ---nchan-list="[6]" \ ---docircular \ ---uvminscalarphasediff=0 \ ---stop=1 \ ---soltypecycles-list="[0]" \ ---imsize=1600 \ ---skymodelpointsource=1.0 \ ---helperscriptspath=$LOFARFACETSELFCAL \ ---helperscriptspathh5merge=$LOFARHELPERS \ ---stopafterskysolve \ -*.ms - -cd ../ \ No newline at end of file diff --git a/phasediff_scores/phasediff_selection/source_selection.sh b/phasediff_scores/phasediff_selection/source_selection.sh deleted file mode 100644 index fabd7504..00000000 --- a/phasediff_scores/phasediff_selection/source_selection.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -#SBATCH -c 10 --constraint=amd - -#MSLIST TEXT FILE WITH PATH TO MS -MSLIST=$1 - -#SINGULARITY -BIND=$( python3 $HOME/parse_settings.py --BIND ) # SEE --> https://github.com/jurjen93/lofar_vlbi_helpers/blob/main/parse_settings.py -SIMG=$( python3 $HOME/parse_settings.py --SIMG ) - -#SCRIPT FOLDER -LOFAR_HELPERS=$( python3 $HOME/parse_settings.py --lofar_helpers ) -SCRIPT_DIR=${LOFAR_HELPERS}/phasediff_scores/phasediff_selection - - -#phasediff output folder -mkdir -p phasediff_h5s - -#RUN MS FROM MS LIST -while read -r MS; do - mkdir ${MS}_folder - cp -r ${MS} ${MS}_folder - cd ${MS}_folder - chmod 755 ${SCRIPT_DIR}/* - singularity exec -B $BIND $SIMG ${SCRIPT_DIR}/phasediff.sh ${MS} - mv scalarphasediff0*phaseup.h5 ../phasediff_h5s - mv ${MS} ../ - cd ../ - rm -rf ${MS}_folder -done <$MSLIST - -#RETURN SCORES -singularity exec -B $BIND $SIMG python ${LOFAR_HELPERS}/source_selection/phasediff_output.py --h5 phasediff_h5s/*.h5 diff --git a/phasediff_scores/phasediff_selection/source_selection_slurm.sh b/phasediff_scores/phasediff_selection/source_selection_slurm.sh deleted file mode 100644 index 224bca87..00000000 --- a/phasediff_scores/phasediff_selection/source_selection_slurm.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -#SBATCH -c 10 --constraint=amd --array=0-30 -t 1:00:00 - -#MSLIST TEXT FILE WITH PATH TO MS -PLIST=$1 - -#SINGULARITY -BIND=$( python3 $HOME/parse_settings.py --BIND ) # SEE --> https://github.com/jurjen93/lofar_vlbi_helpers/blob/main/parse_settings.py -SIMG=$( python3 $HOME/parse_settings.py --SIMG ) -#SCRIPT FOLDER -LOFAR_HELPERS=$( python3 $HOME/parse_settings.py --lofar_helpers ) -SCRIPT_DIR=$PWD - -#phasediff output folder -mkdir -p phasediff_h5s - -DIR=$( awk NR==${SLURM_ARRAY_TASK_ID} $PLIST ) - -echo "$DIR" -mkdir $DIR -cp -r *${DIR}*.ms $DIR -cd $DIR -chmod 755 ${SCRIPT_DIR}/* -singularity exec -B $BIND $SIMG ${SCRIPT_DIR}/phasediff_multi.sh -mv scalarphasediff0*phaseup.h5 ../phasediff_h5s -cd ../ -#rm -rf $DIR - -#RETURN SCORES -singularity exec -B $BIND $SIMG python ${SCRIPT_DIR}/phasediff_output.py --h5 phasediff_h5s/*.h5 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..73fb3ff4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "lofar_helpers" +version = "1.0.0" +description = "LOFAR helper scripts and tools" +dependencies = [] + +[project.scripts] +split_polygon_facets = "ds9_helpers.split_polygon_facets:main" +crop_nan_boundaries = "fits_helpers.crop_nan_boundaries:main" +cut_fits_with_region = "fits_helpers.cut_fits_with_region:main" +close_h5 = "h5_helpers.close_h5:main" +concat_with_dummies = "ms_helpers.concat_with_dummies:main" +remove_flagged_stations = "ms_helpers.remove_flagged_stations:main" +applycal = "ms_helpers.applycal:main" +subtract_with_wsclean = "subtract.subtract_with_wsclean:main" +h5_merger = "h5_merger:main" + +[tool.setuptools] +packages = ["ds9_helpers", "fits_helpers", "h5_helpers", "ms_helpers", "subtract"] diff --git a/source_selection/calibrator_selection/main.py b/source_selection/calibrator_selection/main.py deleted file mode 100644 index 3106503c..00000000 --- a/source_selection/calibrator_selection/main.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np -import matplotlib -import matplotlib.pyplot as plt - -matplotlib.use('TkAgg') - -SPD_INDEX = 0 -SOLINT_INDEX = 1 -RA_IDX = 2 -DEC_IDX = 3 - -MU = 0 -SIG = 0.3 - - -def gaussian(x, mu, sig): - return 1.0 / (np.sqrt(2.0 * np.pi) * sig) * np.exp(-np.power((x - mu) / sig, 2.0) / 2) - - -def apply_gauss(work_arr, source_point, mu=MU, sig=SIG): - temp_arr = np.copy(work_arr) - ra_diff = temp_arr[:, RA_IDX] - source_point[RA_IDX] - dec_diff = temp_arr[:, DEC_IDX] - source_point[DEC_IDX] - - distance = np.sqrt(ra_diff ** 2 + dec_diff ** 2) - - scale = gaussian(0, mu, sig) - - work_arr[:, SPD_INDEX] *= 1 - (gaussian(distance, mu, sig) / scale) - - -def plot_data(arr, selection=None, text=None): - graph, plot = plt.subplots(1, 1) - plt.scatter(arr[:, RA_IDX], arr[:, DEC_IDX], arr[:, SPD_INDEX], alpha=0.5) - if selection is not None: - plt.scatter(selection[:, RA_IDX], selection[:, DEC_IDX], selection[:, SPD_INDEX]) - - if text is not None: - for text, x, y in zip(text, selection[:, RA_IDX], selection[:, DEC_IDX]): - plt.annotate(text, (x, y)) - plot.invert_xaxis() - plt.xlabel("DEC") - plt.ylabel("RA") - plt.show() - - -def compare_selected_names(selection_names, output=False): - selection = sorted(selection_names) - good_selection = np.loadtxt("good-selection.txt", dtype=str) - - s_selection = set(selection) - s_good = set(good_selection) - - if output: - print(f"Selected: {len(s_selection)}. True: {len(s_good)}") - print(f"Same in {len(s_selection.intersection(s_good))}/{len(s_selection.union(s_good))} cases.") - - print("name: S T") - for name in sorted(list(s_selection.union(s_good))): - print(f"{name} {name in selection_names} {name in good_selection}") - - return len(s_selection.intersection(s_good)) / len(s_selection.union(s_good)) - - -def main(): - arr, names = load_data() - - selection_indices = select_sources(arr, threshold=0.1, sig=0.12) - selection_names = names[selection_indices] - compare_selected_names(selection_names, output=True) - - plot_data(arr, arr[selection_indices]) - - # Mask out all non-selected values - mask = np.zeros(len(arr), bool) - mask[selection_indices] = 1 - arr[:, SPD_INDEX] = mask * 10 - - plot_data(arr, arr[selection_indices], text=selection_names) - - -def select_sources(arr, threshold=0.1, sig=SIG): - selection_indices = [] - work_arr = np.copy(arr) - while np.max(work_arr[:, SPD_INDEX]) > 0: - idx = np.argmax(work_arr[:, SPD_INDEX]) - point = np.copy(work_arr[idx]) - if point[SPD_INDEX] < threshold: - break - - selection_indices.append(idx) - - apply_gauss(work_arr, point, sig=sig) - - # Uncomment to find the order of selection. - return selection_indices - - -def load_data(): - arr = np.loadtxt("phasediff_output.csv", delimiter=",", usecols=[1, 2, 3, 4], skiprows=1) - names = np.loadtxt("phasediff_output.csv", delimiter=",", usecols=[0], skiprows=1, dtype=str) - names = np.array([name.split("/", 1)[0] for name in names]) - arr[:, SPD_INDEX] = 1 / (arr[:, SPD_INDEX]) - return arr, names - - -if __name__ == "__main__": - main() diff --git a/source_selection/phasediff_output.py b/source_selection/phasediff_output.py deleted file mode 100644 index 14b3fcd8..00000000 --- a/source_selection/phasediff_output.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -WARNING: THIS SCRIPT HAS BEEN MOVED TO https://github.com/rvweeren/lofar_facet_selfcal REPOSITORY - -This script is used to derive a S/N selection score by using an h5parm with scalarphasediff solutions from facetselfcal. -This is described in Section 3.3 of de Jong et al. (2024) -""" - -author__ = "Jurjen de Jong (jurjendejong@strw.leidenuniv.nl)" -__all__ = ['GetSolint'] - -import tables -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -from scipy.stats import circstd -from glob import glob -import csv -import sys -from argparse import ArgumentParser -from typing import Union - -try: # fix this - from selfcal_selection import parse_source_from_h5 - import scienceplots - plt.style.use(['science', 'ieee']) -except ImportError: - pass - - -def make_utf8(inp): - """ - Convert input to utf8 instead of bytes - - :param inp: string input - :return: input in utf-8 format - """ - - try: - inp = inp.decode('utf8') - return inp - except (UnicodeDecodeError, AttributeError): - return inp - - -def rad_to_degree(inp): - """ - Check if radians and convert to degree - - :param inp: two coordinates (RA, DEC) - :return: output in degrees - """ - - try: - if abs(inp[0]) < np.pi and abs(inp[1]) < np.pi: - return inp * 360 / 2 / np.pi % 360 - else: - return inp - except ValueError: # Sorry for the ugly code.. - if abs(inp[0][0]) < np.pi and abs(inp[0][1]) < np.pi: - return inp[0] * 360 / 2 / np.pi % 360 - else: - return inp[0] - - -class GetSolint: - def __init__(self, h5: str = None, optimal_score: float = 0.5, ref_solint: float = 10., station: str = None): - """ - Get a score based on the phase difference between XX and YY. This reflects the noise in the observation. - From this score we can determine an optimal solution interval, by fitting a wrapped normal distribution. - - See: - - https://en.wikipedia.org/wiki/Wrapped_normal_distribution - - https://en.wikipedia.org/wiki/Yamartino_method - - https://en.wikipedia.org/wiki/Directional_statistics - - :param h5: h5parm - :param optimal_score: score to fit solution interval - :param ref_solint: reference solution interval - :param station: station name - """ - - self.h5 = h5 - self.optimal_score = optimal_score - self.ref_solint = ref_solint - self.cstd = 0 - self.C = None - self.station = station - self.limit = np.pi - - def plot_C(self, title: str = None, saveas: str = None, extrapoints: Union[list, tuple] = None): - """ - Plot circstd score in function of solint for given C - """ - - # normal_sigmas = [n / 1000 for n in range(1, 10000)] - # values = [circstd(normal(0, n, 300)) for n in normal_sigmas] - # x = (self.C*limit**2) / (np.array(normal_sigmas) ** 2) / 2 - bestsolint = self.best_solint - # plt.plot(x, values, alpha=0.5) - solints = np.array(range(1, int(max(bestsolint * 200, self.ref_solint * 150)))) / 100 - plt.plot(solints, [self.theoretical_curve(float(t)) for t in solints], color='green') - plt.scatter([self.ref_solint], [self.cstd], c='blue', label='measurement', s=80, marker='x') - plt.scatter([bestsolint], [self.optimal_score], color='red', label='best solint', s=80, marker='x') - if extrapoints is not None: - plt.scatter(extrapoints[0], extrapoints[1], color='orange', label='other measurements', s=80, marker='x') - plt.xlim(0, max(bestsolint * 1.5, self.ref_solint * 1.5)) - # plt.xlim(0, 0.2) - plt.xlabel("solint (min)") - plt.ylabel("circstd score") - plt.legend(frameon=True, loc='upper right', fontsize=10) - if title is not None: - plt.title(title) - if saveas is not None: - plt.savefig(saveas) - else: - plt.show() - - return self - - def _circvar_to_normvar(self, circ_var: float = None): - """ - Convert circular variance to normal variance - - return: circular variance - """ - - if circ_var >= self.limit ** 2: - return 999 # replacement for infinity - else: - return -2 * np.log(1 - circ_var / (self.limit ** 2)) - - @property - def _get_C(self): - """ - Get constant defining the normal circular distribution - - :return: C - """ - - if self.cstd == 0: - self.get_phasediff_score(station=self.station) - normvar = self._circvar_to_normvar(self.cstd ** 2) - return normvar * self.ref_solint - - def get_phasediff_score(self, station: str = None): - """ - Calculate score for phasediff - - :return: circular standard deviation score - """ - - H = tables.open_file(self.h5) - - stations = [make_utf8(s) for s in list(H.root.sol000.antenna[:]['name'])] - - if station is None or station == '': - stations_idx = [stations.index(stion) for stion in stations if - ('RS' not in stion) & - ('ST' not in stion) & - ('CS' not in stion) & - ('DE' not in stion) & - ('PL' not in stion)] - else: - stations_idx = [stations.index(station)] - - axes = str(H.root.sol000.phase000.val.attrs["AXES"]).replace("b'", '').replace("'", '').split(',') - axes_idx = sorted({ax: axes.index(ax) for ax in axes}.items(), key=lambda x: x[1], reverse=True) - - phase = H.root.sol000.phase000.val[:] * H.root.sol000.phase000.weight[:] - H.close() - - phasemod = phase % (2 * np.pi) - - for ax in axes_idx: - if ax[0] == 'pol': # YX should be zero - phasemod = phasemod.take(indices=0, axis=ax[1]) - elif ax[0] == 'dir': # there should just be one direction - if phasemod.shape[ax[1]] == 1: - phasemod = phasemod.take(indices=0, axis=ax[1]) - else: - sys.exit('ERROR: This solution file should only contain one direction, but it has ' + - str(phasemod.shape[ax[1]]) + ' directions') - elif ax[0] == 'freq': # faraday corrected - if phasemod.shape[ax[1]] == 1: - print("WARNING: only 1 frequency --> Skip frequency diff for Faraday correction (score will be less accurate)") - else: - phasemod = np.diff(phasemod, axis=ax[1]) - elif ax[0] == 'ant': # take only international stations - phasemod = phasemod.take(indices=stations_idx, axis=ax[1]) - - phasemod[phasemod == 0] = np.nan - - self.cstd = circstd(phasemod, nan_policy='omit') - - return circstd(phasemod, nan_policy='omit') - - @property - def best_solint(self): - """ - Get optimal solution interval from phasediff, given C - - :return: value corresponding with increase solution interval - """ - - if self.cstd == 0: - self.get_phasediff_score(station=self.station) - self.C = self._get_C - optimal_cirvar = self.optimal_score ** 2 - return self.C / (self._circvar_to_normvar(optimal_cirvar)) - - def theoretical_curve(self, t): - """ - Theoretical curve based on circ statistics - :return: circular std - """ - - if self.C is None: - self.C = self._get_C - return self.limit * np.sqrt(1 - np.exp(-(self.C / (2 * t)))) - - -def parse_args(): - """ - Command line argument parser - - :return: parsed arguments - """ - - parser = ArgumentParser() - parser.add_argument('--h5', nargs='+', help='selfcal phasediff solutions', default=None) - parser.add_argument('--station', help='for one specific station', default=None) - parser.add_argument('--all_stations', action='store_true', help='for all stations specifically') - parser.add_argument('--make_plot', action='store_true', help='make phasediff plot') - parser.add_argument('--optimal_score', help='optimal score between 0 and pi', default=1.75, type=float) - return parser.parse_args() - -def main(): - - print('WARNING: THIS SCRIPT HAS BEEN MOVED TO https://github.com/rvweeren/lofar_facet_selfcal REPOSITORY\n' - 'This version has therefore not be maintained since September 2024') - - args = parse_args() - - # set std score, for which you want to find the solint - optimal_score = args.optimal_score - - # reference solution interval - ref_solint = 10 - - h5s = args.h5 - if len(h5s)==1 and ' ' in h5s[0]: - h5s = h5s[0].split(" ") - elif h5s is None: - h5s = glob("P*_phasediff/phasediff0*.h5") - - if args.station is not None: - station = args.station - else: - station = '' - - with open('phasediff_output.csv', 'w') as f: - writer = csv.writer(f) - writer.writerow(["source", "spd_score", "best_solint", 'RA', 'DEC']) - for h5 in h5s: - # try: - S = GetSolint(h5, optimal_score, ref_solint) - if args.all_stations: - H = tables.open_file(h5) - stations = [make_utf8(s) for s in list(H.root.sol000.antenna[:]['name'])] - H.close() - else: - stations = [station] - for station in stations: - std = S.get_phasediff_score(station=station) - solint = S.best_solint - H = tables.open_file(h5) - dir = rad_to_degree(H.root.sol000.source[:]['dir']) - writer.writerow([parse_source_from_h5(h5) + station, std, solint, dir[0], dir[1]]) - if args.make_plot: - S.plot_C("T=" + str(round(solint, 2)) + " min", saveas=h5 + station + '.png') - H.close() - # except: - # pass - - # sort output - df = pd.read_csv('phasediff_output.csv').sort_values(by='spd_score') - df.to_csv('phasediff_output.csv', index=False) - - - -if __name__ == '__main__': - main() diff --git a/source_selection/pre_selection/ddcal_pre_selection.cwl b/source_selection/pre_selection/ddcal_pre_selection.cwl deleted file mode 100644 index 9a6ca349..00000000 --- a/source_selection/pre_selection/ddcal_pre_selection.cwl +++ /dev/null @@ -1,64 +0,0 @@ -class: Workflow -cwlVersion: v1.2 -id: selfcal pre-selection -doc: | - This is a workflow to do a pre-selection for the LOFAR-VLBI pipeline direction-dependent calibrator selection - -requirements: - - class: SubworkflowFeatureRequirement - - class: MultipleInputFeatureRequirement - - class: ScatterFeatureRequirement - -inputs: - - id: msin - type: Directory[] - doc: The input MS. - - id: h5merger - type: Directory - doc: The h5merger directory. - - id: selfcal - type: Directory - doc: The selfcal directory. - -steps: - - id: dp3_prephasediff - label: Pre-averaging with DP3 - in: - - id: msin - source: msin - out: - - id: phasediff_ms - scatter: msin - scatterMethod: dotproduct - run: ./steps/dp3_prephasediff.cwl - - - id: get_phasediff - label: Get phase difference with facetselfcal - in: - - id: phasediff_ms - source: dp3_prephasediff/phasediff_ms - - id: h5merger - source: h5merger - - id: selfcal - source: selfcal - out: - - phasediff_h5out - scatterMethod: dotproduct - scatter: phasediff_ms - run: ./steps/get_phasediff.cwl - - - id: get_source_scores - label: Calculate phase difference score - in: - - id: phasediff_h5 - source: get_phasediff/phasediff_h5out - - id: selfcal - source: selfcal - out: - - phasediff_score_csv - run: ./steps/get_source_scores.cwl - -outputs: - - id: phasediff_score_csv - type: File - outputSource: get_source_scores/phasediff_score_csv diff --git a/source_selection/pre_selection/steps/dp3_prephasediff.cwl b/source_selection/pre_selection/steps/dp3_prephasediff.cwl deleted file mode 100644 index 5aab62e1..00000000 --- a/source_selection/pre_selection/steps/dp3_prephasediff.cwl +++ /dev/null @@ -1,56 +0,0 @@ -class: CommandLineTool -cwlVersion: v1.2 -id: pre_averaging_dp3 -label: DP3 Pre-averaging -doc: This tool prepares measurement set for pulling phasediff scores from facetselfcal - - -baseCommand: DP3 - -inputs: - - id: msin - type: Directory - doc: Input measurement set - inputBinding: - prefix: msin= - position: 0 - valueFrom: $(inputs.msin.path) - shellQuote: false - separate: false - -outputs: - - id: phasediff_ms - type: Directory - outputBinding: - glob: "*.phasediff.ms" - - id: logfile - type: File[] - outputBinding: - glob: dp3_prephasediff*.log - -arguments: - - steps=[avg] - - msin.datacolumn=DATA - - msout.storagemanager=dysco - - avg.type=averager - - avg.freqresolution=390.56kHz - - avg.timeresolution=60 - - msout=$(inputs.msin.path+".phasediff.ms") - -requirements: - - class: ShellCommandRequirement - - class: InlineJavascriptRequirement - - class: InitialWorkDirRequirement - listing: - - entry: $(inputs.msin) - writable: true - -hints: - - class: DockerRequirement - dockerPull: vlbi-cwl - - class: ResourceRequirement - coresMin: 6 - - -stdout: dp3_prephasediff.log -stderr: dp3_prephasediff_err.log diff --git a/source_selection/pre_selection/steps/filter_ms.cwl b/source_selection/pre_selection/steps/filter_ms.cwl deleted file mode 100644 index e6be0b47..00000000 --- a/source_selection/pre_selection/steps/filter_ms.cwl +++ /dev/null @@ -1,55 +0,0 @@ -class: CommandLineTool -cwlVersion: v1.2 -id: select_final -baseCommand: python3 - -inputs: - - id: msin - type: Directory[] - doc: The input MS. - - id: phasediff_score_csv - type: File - doc: The phase diff scores in a csv. - -requirements: - - class: ShellCommandRequirement - - class: InlineJavascriptRequirement - - class: InitialWorkDirRequirement - listing: - - entry: $(inputs.phasediff_score_csv) - writable: true - - entry: $(inputs.msin) - writable: true - - entryname: select_final.py - entry: | - import subprocess - import pandas as pd - import json - - inputs = json.loads(r"""$(inputs)""") - mslist = inputs['msin'] - - df = pd.read_csv(inputs['phasediff_score_csv']['location']) - selection = df[df['spd_score'] < 2.4]['source'].to_list() - - for s in selection: - for ms in mslist: - if s in ms['basename']: - subprocess.run(f"cp -r {ms['basename']} {ms['basename']}_selected.ms") - -hints: - - class: DockerRequirement - dockerPull: vlbi-cwl - -outputs: - - id: selected_ms - type: File - outputBinding: - glob: "*_selected.ms" - - id: logfile - type: File[] - outputBinding: - glob: select_final*.log - -stdout: select_final.log -stderr: select_final.log diff --git a/source_selection/pre_selection/steps/get_phasediff.cwl b/source_selection/pre_selection/steps/get_phasediff.cwl deleted file mode 100644 index bcbff8bb..00000000 --- a/source_selection/pre_selection/steps/get_phasediff.cwl +++ /dev/null @@ -1,73 +0,0 @@ -class: CommandLineTool -cwlVersion: v1.2 -id: get_phasediff -label: Polarization Phase Difference -doc: This tool prepares measurement set for pulling phasediff scores from facetselfcal - -baseCommand: python - -inputs: - - id: phasediff_ms - type: Directory - doc: Input measurement set - inputBinding: - position: 20 - valueFrom: $(inputs.phasediff_ms.path) - - id: h5merger - type: Directory - doc: The h5merger directory. - - id: selfcal - type: Directory - doc: The facetselfcal directory. - -outputs: - - id: phasediff_h5out - type: File - outputBinding: - glob: "scalarphasediff*.h5" - - id: logfile - type: File[] - outputBinding: - glob: phasediff*.log - -arguments: - - valueFrom: $( inputs.selfcal.path + '/facetselfcal.py' ) - - valueFrom: -i phasediff - - valueFrom: --forwidefield - - valueFrom: --phaseupstations=core - - valueFrom: --skipbackup - - valueFrom: --uvmin=20000 - - valueFrom: --soltype-list=['scalarphasediff'] - - valueFrom: --solint-list=['10min'] - - valueFrom: --nchan-list=[6] - - valueFrom: --docircular - - valueFrom: --uvminscalarphasediff=0 - - valueFrom: --stop=1 - - valueFrom: --soltypecycles-list=[0] - - valueFrom: --imsize=1600 - - valueFrom: --skymodelpointsource=1.0 - - valueFrom: --helperscriptspath=$(inputs.selfcal.path) - - valueFrom: --helperscriptspathh5merge=$(inputs.h5merger.path) - - valueFrom: --stopafterskysolve - - valueFrom: --phasediff_only - - -requirements: - - class: ShellCommandRequirement - - class: InlineJavascriptRequirement - - class: InitialWorkDirRequirement - listing: - - entry: $(inputs.phasediff_ms) - writable: true - - entry: $(inputs.h5merger) - - entry: $(inputs.selfcal) - -hints: - - class: DockerRequirement - dockerPull: vlbi-cwl - - class: ResourceRequirement - coresMin: 10 - - -stdout: phasediff.log -stderr: phasediff_err.log diff --git a/source_selection/pre_selection/steps/get_source_scores.cwl b/source_selection/pre_selection/steps/get_source_scores.cwl deleted file mode 100644 index a6c17064..00000000 --- a/source_selection/pre_selection/steps/get_source_scores.cwl +++ /dev/null @@ -1,53 +0,0 @@ -class: CommandLineTool -cwlVersion: v1.2 -id: get_source_scores -label: Source Scores -doc: This tool determines the phasediff scores from h5 scalarphasediff input from facetselfcal - -baseCommand: python - -inputs: - - id: phasediff_h5 - type: File[] - doc: Phasedifference h5 from facetselfcal - inputBinding: - prefix: "--h5" - position: 1 - itemSeparator: " " - separate: true - - id: selfcal - type: Directory - doc: The selfcal directory. - - -outputs: - - id: phasediff_score_csv - type: File - outputBinding: - glob: "*.csv" - - id: logfile - type: File[] - outputBinding: - glob: phasediff*.log - - -arguments: - - valueFrom: $( inputs.selfcal.path + '/source_selection/phasediff_output.py' ) - - -requirements: - - class: ShellCommandRequirement - - class: InlineJavascriptRequirement - - class: InitialWorkDirRequirement - listing: - - entry: $(inputs.phasediff_h5) - writable: true - - entry: $(inputs.selfcal) - -hints: - - class: DockerRequirement - dockerPull: vlbi-cwl - - -stdout: phasediff.log -stderr: phasediff_err.log diff --git a/source_selection/selfcal_selection.py b/source_selection/selfcal_selection.py deleted file mode 100644 index 16c4a457..00000000 --- a/source_selection/selfcal_selection.py +++ /dev/null @@ -1,695 +0,0 @@ -""" -WARNING: THIS SCRIPT HAS BEEN MOVED TO https://github.com/rvweeren/lofar_facet_selfcal REPOSITORY - -This script is used to select the best self-calibration cycle from facetselfcal.py see: -https://github.com/rvweeren/lofar_facet_selfcal/blob/main/facetselfcal.py -It will return a few plots and a csv with the statistics for each self-calibration cycle. -This is described in Section 3.3 of de Jong et al. (2024) - -You can run this script in the folder with your facetselfcal output on 1 source as -python selfcal_quality.py --fits *.fits --h5 merged*.h5 - -Alternatively (for testing), you can also run the script on multiple sources at the same time with -python selfcal_quality.py --parallel --root_folder_parallel -where '' is the path directed to where all calibrator source folders are located -""" - -__author__ = "Jurjen de Jong (jurjendejong@strw.leidenuniv.nl), Robert Jan Schlimbach (robert-jan.schlimbach@surf.nl)" - -import logging -import functools -import re -import csv -from pathlib import Path -from os import sched_getaffinity, system -import sys - -from joblib import Parallel, delayed -import tables -from scipy.stats import circstd, linregress -import numpy as np -import matplotlib.pyplot as plt -from argparse import ArgumentParser -from astropy.io import fits -import pandas as pd -from typing import Union - -logger = logging.getLogger(__file__) -logging.basicConfig(level=logging.INFO, format='%(message)s') - - -class SelfcalQuality: - def __init__(self, h5s: list, fitsim: list, station: str): - """ - Determine quality of selfcal from facetselfcal.py - - :param fitsim: fits images - :param h5s: h5parm solutions - :param folder: path to directory where selfcal ran - :param station: which stations to consider [dutch, remote, international, debug] - """ - - # merged selfcal h5parms - self.h5s = h5s - assert len(self.h5s) != 0, "No h5 files given" - - # selfcal images - self.fitsfiles = fitsim - assert len(self.fitsfiles) != 0, "No fits files found" - - # select all sources - sources = [] - for h5 in self.h5s: - filename = h5.split('/')[-1] - sources.append(parse_source_from_h5(filename)) - self.sources = set(sources) - assert len(self.sources) > 0, "No sources found" - - self.main_source = list(self.sources)[0].split("_")[-1] - - # for phase/amp evolution - self.station = station - - self.station_codes = ( - ('CS',) if self.station == 'dutch' else - ('RS',) if self.station == 'remote' else - ('CS', 'RS') if self.station == 'alldutch' else - ('IE', 'SE', 'PL', 'UK', 'LV', 'DE') # i.e.: if self.station == 'international' - ) - - # def image_entropy(self, fitsfile: str = None): - # """ - # Calculate entropy of image - # - # :param fitsfile: - # - # :return: image entropy value - # """ - # - # with fits.open(fitsfile) as f: - # image = f[0].data - # - # while image.ndim > 2: - # image = image[0] - # image = np.sqrt((image - self.minp) / (self.maxp - self.minp)) * 255 - # image = image.astype(np.uint8) - # val = entropy(image, disk(6)).sum() - # print(f"Entropy: {val}") - # return val - - def filter_stations(self, station_names): - """Generate indices of filtered stations""" - - if self.station == 'debug': - return list(range(len(station_names))) - - output_stations = [ - i for i, station_name in enumerate(station_names) - if any(station_code in station_name for station_code in self.station_codes) - ] - - return output_stations - - def print_station_names(self, h5): - """Print station names""" - - with tables.open_file(h5) as H: - station_names = H.root.sol000.antenna[:]['name'] - - stations = map(make_utf8, station_names) - - stations_used = ', '.join([ - station_name for station_name in stations - if any(station_code in station_name for station_code in self.station_codes) - ]) - logger.debug(f'Used the following stations: {stations_used}') - return self - - def get_solution_scores(self, h5_1: str, h5_2: str = None): - """ - Get solution scores - - :param h5_1: solution file 1 - :param h5_2: solution file 2 - - :return: phase_score --> circular std phase difference score - amp_score --> std amp difference score - """ - - def extract_data(tables_path): - with tables.open_file(tables_path) as f: - axes = make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(',') - pols = f.root.sol000.phase000.pol[:] if 'pol' in axes else ['XX'] - amps = ( - f.root.sol000.amplitude000.val[:] - if 'amplitude000' in list(f.root.sol000._v_groups.keys()) - else np.ones(f.root.sol000.phase000.val.shape) - ) - - return ( - [make_utf8(station) for station in f.root.sol000.antenna[:]['name']], - make_utf8(f.root.sol000.phase000.val.attrs['AXES']).split(','), - pols, - f.root.sol000.phase000.val[:], - f.root.sol000.phase000.weight[:], - amps, - ) - - def filter_params(station_indices, axes, *parameters): - return tuple( - np.take(param, station_indices, axes) - for param in parameters - ) - - def weighted_vals(vals, weights): - return np.nan_to_num(vals) * weights - - station_names1, axes1, phase_pols1, *params1 = extract_data(h5_1) - - antenna_selection = functools.partial(filter_params, self.filter_stations(station_names1), axes1.index('ant')) - phase_vals1, phase_weights1, amps1 = antenna_selection(*params1) - - prep_phase_score = weighted_vals(phase_vals1, phase_weights1) - prep_amp_score = weighted_vals(amps1, phase_weights1) - - if h5_2 is not None: - _, _, phase_pols2, *params2 = extract_data(h5_2) - phase_vals2, phase_weights2, amps2 = antenna_selection(*params2) - - min_length = min(len(phase_pols1), len(phase_pols2)) - assert 0 < min_length <= 4 - - indices = [0] if min_length == 1 else [0, -1] - - if 'pol' in axes1: - prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 = filter_params( - indices, axes1.index('pol'), prep_phase_score, prep_amp_score, phase_vals2, phase_weights2, amps2 - ) - - # np.seterr(all='raise') - prep_phase_score = np.subtract(prep_phase_score, weighted_vals(phase_vals2, phase_weights2)) - prep_amp_score = np.divide( - prep_amp_score, - weighted_vals(amps2, phase_weights2), - out=np.zeros_like(prep_amp_score), - where=phase_weights2 != 0 - ) - - phase_score = circstd(prep_phase_score[prep_phase_score != 0], nan_policy='omit') - amp_score = np.std(prep_amp_score[prep_amp_score != 0]) - - return phase_score, amp_score - - def solution_stability(self): - """ - Get solution stability scores and make figure - - :return: bestcycle --> best cycle according to solutions - accept --> accept this selfcal - """ - - # loop over sources to get scores - for k, source in enumerate(self.sources): - logger.debug(source) - - sub_h5s = sorted([h5 for h5 in self.h5s if source in h5]) - - phase_scores = [] - amp_scores = [] - - for m, sub_h5 in enumerate(sub_h5s): - number = get_cycle_num(sub_h5) - - phase_score, amp_score = self.get_solution_scores(sub_h5, sub_h5s[m - 1] if number > 0 else None) - - phase_scores.append(phase_score) - amp_scores.append(amp_score) - - if k == 0: - total_phase_scores = [phase_scores] - total_amp_scores = [amp_scores] - - total_phase_scores = np.append(total_phase_scores, [phase_scores], axis=0) - total_amp_scores = np.append(total_amp_scores, [amp_scores], axis=0) - - # plot - finalphase, finalamp = (np.mean(score, axis=0) for score in (total_phase_scores, total_amp_scores)) - - return finalphase, finalamp - - @staticmethod - def solution_accept_reject(finalphase, finalamp): - - bestcycle = int(np.array(finalphase).argmin()) - - # selection based on slope - phase_decrease = linreg_slope(finalphase[:bestcycle+1]) - if not all(v == 0 for v in finalamp) and len(finalamp) >= 3: - # amplitude solves start typically later than phase solves - start_cycle = 0 - for a in finalamp: - if a==0: - start_cycle += 1 - if bestcycle-start_cycle > 3: - amp_decrease = linreg_slope([i for i in finalamp if i != 0][start_cycle:bestcycle+1]) - else: - amp_decrease = 0 - else: - amp_decrease = 0 - accept = ( - (phase_decrease <= 0 - or amp_decrease <= 0) - and bestcycle >= 1 - and finalphase[bestcycle] < 1 - and finalphase[0] > finalphase[bestcycle] - ) - return bestcycle, accept - - @staticmethod - def image_stability(rmss, minmaxs): - """ - Determine image stability - - :param minmaxs: absolute values of min/max for each self-cal cycle - :param rmss: rms (noise) for each self-cal cycle - - :return: bestcycle --> best solution cycle - accept --> accept this selfcal - """ - - # metric scores - combined_metric = min_max_norm(rmss) * min_max_norm(minmaxs) - - # best cycle - bestcycle = select_cycle(combined_metric) - - # getting slopes for selection - rms_slope, minmax_slope = linregress(list(range(len(rmss[:bestcycle+1]))), rmss[:bestcycle+1]).slope, linregress( - list(range(len(rmss[:bestcycle+1]))), - np.array( - minmaxs[:bestcycle+1])).slope - - # acceptance criteria - accept = ((rms_slope <= 0 - or minmax_slope <= 0) - and rmss[0] > rmss[bestcycle] - and minmaxs[0] > minmaxs[bestcycle] - and bestcycle >= 1) - - return bestcycle, accept - - def peak_flux_constraint(self): - """ - Validate if the peak flux is larger than 100 times the local rms - """ - return get_peakflux(self.fitsfiles[0])/get_rms(self.fitsfiles[0]) > 100 - - -def parse_source_from_h5(h5): - """ - Parse sensible output names - """ - h5 = h5.split("/")[-1] - if 'ILTJ' in h5: - matches = re.findall(r'ILTJ\d+\..\d+\+\d+.\d+_L\d+', h5) - if len(matches)==0: - matches = re.findall(r'ILTJ\d+\..\d+\+\d+.\d+', h5) - if len(matches)==0: - print("WARNING: Difficulty with parsing the source name form " + h5) - output = (re.sub('(\D)\d{3}\_', '', h5). - replace("merged_", ""). - replace('addCS_', ''). - replace('selfcalcyl', ''). - replace('selfcalcyle', ''). - replace('.ms', ''). - replace('.copy', ''). - replace('.phaseup', ''). - replace('.h5', ''). - replace('.dp3', ''). - replace('-concat', ''). - replace('.phasediff',''). - replace('_uv',''). - replace('scalarphasediff0_sky','')) - print('Parsed into ' + h5) - return output - output = matches[0] - elif 'selfcalcyle' in h5: - matches = re.findall(r'selfcalcyle\d+_(.*?)\.', h5) - output = matches[0] - else: - print("WARNING: Difficulty with parsing the source name form "+h5) - output = (re.sub('(\D)\d{3}\_', '', h5). - replace("merged_", ""). - replace('addCS_', ''). - replace('selfcalcyl', ''). - replace('selfcalcyle', ''). - replace('.ms', ''). - replace('.copy', ''). - replace('.phaseup', ''). - replace('.h5', ''). - replace('.dp3', ''). - replace('-concat', ''). - replace('.phasediff', ''). - replace('_uv', ''). - replace('scalarphasediff0_sky', '')) - print('Parsed into '+h5) - - return output - - -def min_max_norm(lst): - """Normalize list values between 0 and 1""" - - # find the minimum and maximum - min_value = min(lst) - max_value = max(lst) - - # normalize - normalized_floats = [(x - min_value) / (max_value - min_value) for x in lst] - - return np.array(normalized_floats) - - -def linreg_slope(values=None): - """ - Fit linear regression and return slope - - :param values: Values - - :return: linear regression slope - """ - - return linregress(list(range(len(values))), values).slope - - -def get_minmax(inp: Union[str, np.ndarray]): - """ - Get min/max value - - :param inp: fits file name or numpy array - - :return: minmax --> pixel min/max value - """ - if isinstance(inp, str): - with fits.open(inp) as hdul: - data = hdul[0].data - else: - data = inp - - minmax = np.abs(data.min() / data.max()) - - logger.debug(f"min/max: {minmax}") - return minmax - -def get_peakflux(inp: Union[str, np.ndarray]): - """ - Get min/max value - - :param inp: fits file name or numpy array - - :return: minmax --> pixel min/max value - """ - if isinstance(inp, str): - with fits.open(inp) as hdul: - data = hdul[0].data - else: - data = inp - - mx = data.max() - - logger.debug(f"Peak flux: {mx}") - return mx - - -def select_cycle(cycles=None): - """ - Select best cycle - - :param cycles: rms or minmax cycles - - :return: best cycle - """ - - b, best_cycle = 0, 0 - for n, c in enumerate(cycles[1:]): - if c > cycles[n - 1]: - b += 1 - else: - b = 0 - best_cycle = n + 1 - if b == 2: - break - return best_cycle - - -def get_rms(inp: Union[str, np.ndarray], maskSup: float = 1e-7): - """ - find the rms of an array, from Cycil Tasse/kMS - - :param inp: fits file name or numpy array - :param maskSup: mask threshold - - :return: rms --> rms of image - """ - - if isinstance(inp, str): - with fits.open(inp) as hdul: - data = hdul[0].data - else: - data = inp - - mIn = np.ndarray.flatten(data) - m = mIn[np.abs(mIn) > maskSup] - rmsold = np.std(m) - diff = 1e-1 - cut = 3. - med = np.median(m) - - for i in range(10): - ind = np.where(np.abs(m - med) < rmsold * cut)[0] - rms = np.std(m[ind]) - if np.abs((rms - rmsold) / rmsold) < diff: - break - rmsold = rms - - logger.debug(f'rms: {rms}') - - return rms # jy/beam - - -def get_cycle_num(fitsfile: str = None) -> int: - """ - Parse cycle number - - :param fitsfile: fits file name - """ - - cycle_num = int(float(re.findall(r"selfcalcyle(\d+)", fitsfile.split('/')[-1])[0])) - assert cycle_num >= 0 - return cycle_num - - -def make_utf8(inp=None): - """ - Convert input to utf8 instead of bytes - - :param inp: string input - """ - - try: - inp = inp.decode('utf8') - return inp - except (UnicodeDecodeError, AttributeError): - return inp - - -def make_figure(vals1=None, vals2=None, label1=None, label2=None, plotname=None, bestcycle=None): - """ - Make figure (with optionally two axis) - - :param vals1: values 1 - :param vals2: values 2 - :param label1: label corresponding to values 1 - :param label2: label corresponding to values 2 - :param plotname: plot name - :param bestcycle: plot best cycle - """ - - plt.style.use('ggplot') - - fig, ax1 = plt.subplots() - - color = 'red' - ax1.set_xlabel('cycle') - ax1.set_ylabel(label1, color="tab:"+color) - ax1.plot([i for i in range(len(vals1))], vals1, color='dark'+color, linewidth=2, marker='s', - markerfacecolor="tab:"+color, markersize=3, alpha=0.7, dash_capstyle='round', dash_joinstyle='round') - ax1.tick_params(axis='y', labelcolor="tab:"+color) - ax1.grid(False) - ax1.plot([bestcycle, bestcycle], [0, max(vals1)], linestyle='--', color='black') - ax1.set_ylim(0, max(vals1)) - - if vals2 is not None: - - ax2 = ax1.twinx() - - color = 'blue' - ax2.set_ylabel(label2, color="tab:"+color) - ax2.plot([i for i, v in enumerate(vals2) if v!=0], [v for v in vals2 if v!=0], color='dark'+color, linewidth=2, - marker='s', markerfacecolor="tab:"+color, markersize=3, alpha=0.7, dash_capstyle='round', - dash_joinstyle='round') - ax2.tick_params(axis='y', labelcolor="tab:"+color) - ax2.grid(False) - ax2.set_ylim(0, max(vals2)) - - fig.tight_layout() - - plt.savefig(plotname, dpi=150) - - -def parse_args(): - """ - Command line argument parser - - :return: parsed arguments - """ - - parser = ArgumentParser(description='Determine selfcal quality') - parser.add_argument('--fits', nargs='+', help='selfcal fits images') - parser.add_argument('--h5', nargs='+', help='h5 solutions') - parser.add_argument('--station', type=str, help='', default='international', - choices=['dutch', 'remote', 'alldutch', 'international', 'debug']) - parser.add_argument('--parallel', action='store_true', help='run parallel over multiple sources (for testing)') - parser.add_argument('--root_folder_parallel', type=str, help='root folder (if parallel is on), ' - 'which is the path to where all calibrator source folders are located') - return parser.parse_args() - - -def main(h5s: list = None, fitsfiles: list = None, station: str = 'international'): - """ - Main function - - Input: - - List of h5 files - - List of fits files - - Station type - - Returns: - - Source name - - Accept source (total) - - Best cycle (total) - - Accept according to solutions - - Best cycle according to solutions - - Accept according to images - - Best cycle according to images - - Best h5parm - """ - sq = SelfcalQuality(h5s, fitsfiles, station) - - assert len(sq.h5s) > 1 and len(sq.fitsfiles) > 1, "Need more than 1 h5 or fits file" - - finalphase, finalamp = sq.solution_stability() - bestcycle_solutions, accept_solutions = sq.solution_accept_reject(finalphase, finalamp) - - rmss = [get_rms(fts) * 1000 for fts in sq.fitsfiles] - minmaxs = [get_minmax(fts) for fts in sq.fitsfiles] - - bestcycle_image, accept_image_stability = sq.image_stability(rmss, minmaxs) - accept_peak = sq.peak_flux_constraint() - - best_cycle = int(round((bestcycle_solutions + bestcycle_image - 1)//2, 0)) - # final accept - accept = (accept_image_stability - and accept_peak - and accept_solutions - and (rmss[best_cycle] < rmss[0] or minmaxs[best_cycle] < minmaxs[0]) - ) - - logger.info( - f"{sq.main_source} | " - f"Best cycle according to images: {bestcycle_image}, accept image: {accept_image_stability}. " - f"Best cycle according to solutions: {bestcycle_solutions}, accept solution: {accept_solutions}. " - ) - - logger.info( - f"{sq.main_source} | accept: {accept}, best solutions: {sq.h5s[best_cycle]}" - ) - - fname = f'./selection_output/selfcal_performance_{sq.main_source}.csv' - system(f'mkdir -p ./selection_output') - with open(fname, 'w') as textfile: - # output csv - csv_writer = csv.writer(textfile) - csv_writer.writerow(['solutions', 'dirty'] + [str(i) for i in range(len(sq.fitsfiles))]) - - # best cycle based on phase solution stability - csv_writer.writerow(['phase', np.nan] + list(finalphase)) - csv_writer.writerow(['amp', np.nan] + list(finalamp)) - - csv_writer.writerow(['min/max'] + minmaxs + [np.nan]) - csv_writer.writerow(['rms'] + rmss + [np.nan]) - - make_figure(finalphase, finalamp, 'Phase stability', 'Amplitude stability', f'./selection_output/solution_stability_{sq.main_source}.png', best_cycle) - make_figure(rmss, minmaxs, 'RMS (mJy/beam)', '$|min/max|$', f'./selection_output/image_stability_{sq.main_source}.png', best_cycle) - - df = pd.read_csv(fname).set_index('solutions').T - df.to_csv(fname, index=False) - - return sq.main_source, accept, best_cycle, accept_solutions, bestcycle_solutions, accept_image_stability, bestcycle_image, sq.h5s[best_cycle] - - -def calc_all_scores(sources_root, stations='international'): - """ - For parallel calculation (for testing purposes only) - """ - - def get_solutions(item): - item = Path(item) - if not (item.is_dir() and any(file.suffix == '.h5' for file in item.iterdir())): - return None - star_folder, star_name = item, item.name - - try: - return main(list(map(str, sorted(star_folder.glob('merged*.h5')))), - list(map(str, sorted(star_folder.glob('*MFS-*image.fits')))), stations) - except Exception as e: - logger.warning(f"skipping {star_folder} due to {e}") - return star_name, None, None, None, None, None, None - - all_files = [p for p in Path(sources_root).iterdir() if p.is_dir()] - - results = Parallel(n_jobs=len(sched_getaffinity(0)))(delayed(get_solutions)(f) for f in all_files) - - results = filter(None, results) - - fname = f'./selection_output/selfcal_performance.csv' - system(f'mkdir -p ./selection_output') - with open(fname, 'w') as textfile: - csv_writer = csv.writer(textfile) - csv_writer.writerow( - ['source', 'accept', 'bestcycle', 'accept_solutions', 'bestcycle_solutions', 'accept_images', 'bestcycle_images', 'best_h5'] - ) - - for res in results: - # output csv - csv_writer.writerow(res) - - -if __name__ == '__main__': - - print('WARNING: THIS SCRIPT HAS BEEN MOVED TO https://github.com/rvweeren/lofar_facet_selfcal REPOSITORY\n' - 'This version has therefore not be maintained since September 2024') - - args = parse_args() - - output_folder='./selection_output' - system(f'mkdir -p {output_folder}') - - if args.parallel: - print(f"Running parallel in {args.root_folder_parallel}") - if args.root_folder_parallel is not None: - calc_all_scores(args.root_folder_parallel) - else: - sys.exit("ERROR: if parallel, you need to specify --root_folder_parallel") - else: - main(args.h5, args.fits, args.station) diff --git a/subtract/__init__.py b/subtract/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toil_examples/source_selection.sh b/toil_examples/source_selection.sh deleted file mode 100644 index 530681e7..00000000 --- a/toil_examples/source_selection.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -#NOTE: works only with TOIL>6.0.0 - -#### UPDATE THESE #### - -export TOIL_SLURM_ARGS="--export=ALL --job-name phasediff -p normal -t 4:00:00" -SING_BIND="/project,/project/lofarvwf/Software,/project/lofarvwf/Share,/project/lofarvwf/Public,/home/lofarvwf-jdejong" -VLBI_SCRIPTS="/project/lofarvwf/Software/vlbi/scripts/" - -CWL_WORKFLOW=/project/lofarvwf/Software/lofar_helpers/source_selection/pre_selection/ddcal_pre_selection.cwl -YAML=input.yaml -VENV=/home/lofarvwf-jdejong/venv - -###################### - -# set up singularity -SIMG=vlbi-cwl.sif -mkdir -p singularity -wget https://lofar-webdav.grid.sara.nl/software/shub_mirror/tikk3r/lofar-grid-hpccloud/amd/flocs_v4.5.0_znver2_znver2_aocl4_cuda.sif -O singularity/$SIMG -mkdir -p singularity/pull -cp singularity/$SIMG singularity/pull/$SIMG - -CONTAINERSTR=$(singularity --version) -if [[ "$CONTAINERSTR" == *"apptainer"* ]]; then - export APPTAINER_CACHEDIR=$PWD/singularity - export APPTAINER_TMPDIR=$APPTAINER_CACHEDIR/tmp - export APPTAINER_PULLDIR=$APPTAINER_CACHEDIR/pull - export APPTAINER_BIND=$SING_BIND - export APPTAINERENV_PYTHONPATH='${VLBI_SCRIPTS}:$PYTHONPATH' -else - export SINGULARITY_CACHEDIR=$PWD/singularity - export SINGULARITY_TMPDIR=$SINGULARITY_CACHEDIR/tmp - export SINGULARITY_PULLDIR=$SINGULARITY_CACHEDIR/pull - export SINGULARITY_BIND=$SING_BIND - export SINGULARITYENV_PYTHONPATH='${VLBI_SCRIPTS}:$PYTHONPATH' -fi - -export CWL_SINGULARITY_CACHE=$APPTAINER_CACHEDIR -export TOIL_CHECK_ENV=True - -# make folder for running toil -WORKDIR=$PWD/workdir -OUTPUT=$PWD/outdir -JOBSTORE=$PWD/jobstore -LOGDIR=$PWD/logs -TMPD=$PWD/tmpdir - -mkdir -p ${TMPD}_interm -mkdir -p $WORKDIR -mkdir -p $OUTPUT -mkdir -p $LOGDIR - -source ${VENV}/bin/activate - -# run toil -toil-cwl-runner \ ---no-read-only \ ---retryCount 0 \ ---singularity \ ---disableCaching \ ---writeLogsFromAllJobs True \ ---logFile full_log.log \ ---writeLogs ${LOGDIR} \ ---outdir ${OUTPUT} \ ---tmp-outdir-prefix ${TMPD}/ \ ---jobStore ${JOBSTORE} \ ---workDir ${WORKDIR} \ ---coordinationDir ${OUTPUT} \ ---tmpdir-prefix ${TMPD}_interm/ \ ---disableAutoDeployment True \ ---bypass-file-store \ ---preserve-entire-environment \ ---batchSystem slurm \ -${CWL_WORKFLOW} ${YAML} - -#--cleanWorkDir never \ --> for testing - -deactivate diff --git a/toil_examples/source_selection.yaml b/toil_examples/source_selection.yaml deleted file mode 100644 index 2f6d58d9..00000000 --- a/toil_examples/source_selection.yaml +++ /dev/null @@ -1,11 +0,0 @@ -msin: - - class: "Directory" - path: "/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/ddcal/alldirs/L686962_P17565.ms" - - class: "Directory" - path: "/project/lofarvwf/Share/jdejong/output/ELAIS/ALL_L/ddcal/alldirs/L686962_P16883.ms" -selfcal: - class: "Directory" - path: "/project/lofarvwf/Software/lofar_facet_selfcal" -h5merger: - class: "Directory" - path: "/project/lofarvwf/Software/lofar_helpers"