
Commit

refactor
jurjen93 committed Jul 2, 2024
1 parent dfbfee6 commit e804c52
Showing 2 changed files with 105 additions and 80 deletions.
3 changes: 0 additions & 3 deletions fits_helpers/image_utils/fits_diff.py
@@ -10,6 +10,3 @@ def get_fits_diff(fits_file_1, fits_file_2, out_name='diff.fits'):

hdu = fits.PrimaryHDU(header=f1.header, data=f1.data - f2.data)
hdu.writeto(out_name, overwrite=True)

f1.close()
f2.close()
182 changes: 105 additions & 77 deletions ms_stacker.py → ms_merger.py
@@ -1,9 +1,9 @@
"""
LOFAR UV STACKER
This script can be used to stack measurement sets in the UV plane
LOFAR SIDEREAL VISIBILITY AVERAGER
This script can be used to average visibilities over sidereal time when multiple observations of the same FoV are available.
Example: python ms_stacker.py --msout <MS_NAME> *.ms
The wildcard is in this example stacking a collection of measurement sets
Example: python ms_merger.py --msout <MS_NAME> *.ms
The wildcard in this example collects the measurement sets to combine
Strategy:
1) Make a template using the 'default_ms' option from casacore.tables (Template class).
@@ -17,8 +17,8 @@
4) Make exact mapping between input MS and template MS, using only UV data points.
5) Stack measurement sets on the template (Stack class).
The stacking is done with a weighted average, using the FLAG and WEIGHT_SPECTRUM columns.
5) Average measurement sets in the template (Stack class).
The averaging is done with a weighted average, using the FLAG and WEIGHT_SPECTRUM columns.
"""

from casacore.tables import table, default_ms, taql
@@ -37,6 +37,7 @@
from itertools import compress
import time
import dask.array as da
from dask import delayed
import psutil
from math import ceil
from glob import glob
@@ -49,7 +50,6 @@
try:
from dask.distributed import Client
use_dask_distributed = True
print('AWESOME: dask.distributed installed, continue with advanced parallel stacking.')
except ImportError:
print('WARNING: dask.distributed not installed, continue without parallel stacking.')
use_dask_distributed = False
@@ -645,6 +645,25 @@ def load_json(file_path):
return json.load(file)


def remove_flagged_entries(input_table):
"""
Remove flagged entries.
Note that this corrupts the time axis.
"""

output_table = input_table+'.copy.tmp'

# Select rows that do not match the deletion criteria
selected_rows = taql(f'select from {input_table} where not all(WEIGHT_SPECTRUM == 0)')

# Create a new table with the selected rows
selected_rows.copy(output_table, deep=True)

# Replace with input (overwrite)
shutil.rmtree(input_table)
shutil.move(output_table, input_table)
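
# Minimal numpy equivalent of the TaQL selection above (illustration only):
# keep a row unless every WEIGHT_SPECTRUM entry in it is zero. The argument is
# assumed to be the array returned by getcol('WEIGHT_SPECTRUM'), with shape
# (nrow, nchan, npol).
def _rows_to_keep_sketch(weight_spectrum):
    import numpy as np
    return ~np.all(weight_spectrum == 0, axis=(1, 2))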


class Template:
"""Make template measurement set based on input measurement sets"""
def __init__(self, msin: list = None, outname: str = 'empty.ms'):
@@ -781,9 +800,14 @@ def add_stations(self):
tnew_station.flush(True)
tnew_station.close()

def make_template(self, overwrite: bool = True, avg_factor: float = 1):
def make_template(self, overwrite: bool = True, time_res: int = None, avg_factor: float = 1):
"""
Make template MS based on existing MS
:param:
- overwrite: overwrite output file
- time_res: time resolution in seconds
- avg_factor: averaging factor
"""

if overwrite:
@@ -831,7 +855,10 @@ def process_ms(ms):
self.channels = np.sort(np.expand_dims(np.unique(chan_range), 0))
self.chan_num = self.channels.shape[-1]
# time_range = np.arange(min_t_lst, min(max_t_lst+min_dt, one_lst_day_sec/2 + min_t_lst + min_dt), min_dt/avg_factor)# ensure just half LST day
time_range = np.arange(min_t_lst, max_t_lst + min_dt, min_dt/avg_factor)
if time_res is not None:
time_range = np.arange(min_t_lst, max_t_lst + min_dt, time_res)
else:
time_range = np.arange(min_t_lst, max_t_lst + min_dt, min_dt/avg_factor)
baseline_count = n_baselines(len(self.station_info))
nrows = baseline_count*len(time_range)
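# Worked example of the sizing above (hypothetical numbers): with min_t_lst = 0 s,
# max_t_lst = 3600 s, min_dt = 4 s and time_res = 2 s, np.arange(0, 3604, 2)
# gives 1802 time slots, so nrows = baseline_count * 1802.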

@@ -1142,17 +1169,20 @@ def __init__(self, msin: list = None, outname: str = 'empty.ms', chunkmem: float
self.template = table(outname, readonly=False, ack=False)
self.mslist = msin
self.outname = outname
self.flag = False


num_cpus = psutil.cpu_count(logical=True)
total_memory = psutil.virtual_memory().total / (1024 ** 3) # in GB

if use_dask_distributed:
print(f"CPU number: {num_cpus}\nMemory limit: {int(total_memory)}GB")
print('POSITIVE: dask.distributed installed, continue with advanced parallel stacking.')
print(f"\nCPU number: {num_cpus}\nMemory limit: {int(total_memory)}GB\n")
self.client = Client(n_workers=num_cpus,
threads_per_worker=1,
memory_limit=f'{total_memory/(num_cpus*1.5)}GB')
target_chunk_size = total_memory / num_cpus / chunkmem
self.chunk_size = int(target_chunk_size * (1024 ** 3) / np.dtype(np.float128).itemsize)
self.chunk_size = min(int(target_chunk_size * (1024 ** 3) / np.dtype(np.float128).itemsize), 1_000_000)
else:
target_chunk_size = total_memory / chunkmem
self.chunk_size = min(int(target_chunk_size * (1024 ** 3) / np.dtype(np.float128).itemsize), 1_000_000)
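# Worked example of the chunk sizing above (hypothetical machine): with 64 GB of
# RAM and chunkmem=4, target_chunk_size = 64/4 = 16 GB in this branch; one
# float128 element takes 16 bytes, so 16 GB holds about 1.07e9 elements and the
# min(..., 1_000_000) cap sets chunk_size to 1_000_000.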
@@ -1235,12 +1265,6 @@ def read_mapping(mapping_folder):
else:
new_data, _ = get_data_arrays(col, self.T.nrows(), freq_len)

# Scatter data with dask
if use_dask_distributed:
new_data_future = da.from_delayed(self.client.scatter(new_data, broadcast=True), shape=new_data.shape, dtype=new_data.dtype)
if col == 'UVW':
uvw_weights_future = da.from_delayed(self.client.scatter(uvw_weights, broadcast=True), shape=uvw_weights.shape, dtype=uvw_weights.dtype)

# Loop over measurement sets
for ms in self.mslist:

@@ -1281,31 +1305,21 @@ def read_mapping(mapping_folder):

# Stack columns (using dask for multi-processing)
if col == 'UVW':
if use_dask_distributed:
new_data_dask = da.from_delayed(new_data_future[row_idxs_new, :],
shape=new_data[row_idxs_new, :].shape,
dtype=new_data.dtype)
else:
new_data_dask = da.from_array(new_data[row_idxs_new, :], chunks='auto')
new_data[row_idxs_new, :] = (new_data_dask + data[row_idxs, :]).compute()
new_data_dask_subset = da.from_array(new_data[row_idxs_new, :], chunks='auto')
updated_data = (new_data_dask_subset + data[row_idxs, :]).compute()
new_data[row_idxs_new, :] = updated_data
new_data.flush()

if use_dask_distributed:
uvw_weights_dask = da.from_delayed(uvw_weights_future[row_idxs_new, :],
shape=uvw_weights[row_idxs_new, :].shape,
dtype=uvw_weights.dtype)
else:
uvw_weights_dask = da.from_array(uvw_weights[row_idxs_new, :], chunks='auto')
uvw_weights[row_idxs_new, :] = (uvw_weights_dask + 1).compute()
uvw_weights_dask_subset = da.from_array(uvw_weights[row_idxs_new, :], chunks='auto')
updated_weights = (uvw_weights_dask_subset + 1).compute()
uvw_weights[row_idxs_new, :] = updated_weights
uvw_weights.flush()
else:
if use_dask_distributed:
new_data_dask = da.from_delayed(new_data_future[np.ix_(row_idxs_new, freq_idxs)],
shape=new_data[np.ix_(row_idxs_new, freq_idxs)].shape,
dtype=new_data.dtype)
else:
new_data_dask = da.from_array(new_data[np.ix_(row_idxs_new, freq_idxs)], chunks='auto')
new_data[np.ix_(row_idxs_new, freq_idxs)] = (new_data_dask + data[row_idxs, :]).compute()
new_data_dask_subset = da.from_array(new_data[np.ix_(row_idxs_new, freq_idxs)], chunks='auto')
updated_data = (new_data_dask_subset + data[row_idxs, :]).compute()
new_data[np.ix_(row_idxs_new, freq_idxs)] = updated_data
new_data.flush()

new_data.flush()
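# The accumulation pattern above, in miniature (illustration only): wrap the
# current partial sum in a dask array, add the incoming rows lazily, and
# materialise the result once with .compute(), e.g.
#
#     lazy = da.from_array(partial_sum, chunks='auto') + incoming_rows
#     partial_sum = lazy.compute()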


@@ -1324,14 +1338,18 @@ def read_mapping(mapping_folder):

self.T.close()

if self.flag:
# ADD FLAG
print('Put column FLAG')
taql(f'UPDATE {self.outname} SET FLAG = (WEIGHT_SPECTRUM == 0)')
else:
# REMOVE FLAGS
remove_flagged_entries(self.outname)

# NORM DATA
print('Normalise column DATA')
taql(f'UPDATE {self.outname} SET DATA = (DATA / WEIGHT_SPECTRUM) WHERE ANY(WEIGHT_SPECTRUM > 0)')
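# The update above turns the accumulated weighted sum into the weighted mean:
# after stacking, DATA holds sum(w_i * d_i) and WEIGHT_SPECTRUM holds sum(w_i),
# so the division yields the average; rows whose weights are all zero are
# skipped here and handled by the FLAG / remove_flagged_entries branch above.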

# ADD FLAG
print(f'Put column FLAG')
taql(f'UPDATE {self.outname} SET FLAG = (WEIGHT_SPECTRUM == 0)')

print("----------\n")


@@ -1360,7 +1378,7 @@ def check_folder_exists(folder_path):
return os.path.isdir(folder_path)


def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, mappingfiles: list = None, UV=True, saveas=None):
def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, baseline='0-1', UV=True, saveas=None):
"""
Plot baseline track
@@ -1373,37 +1391,40 @@ def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, ma
if len(t_input_names) > 4:
sys.exit("ERROR: Can just plot 4 inputs")

colors = ['red', 'orange', 'pink', 'brown']
colors = ['red', 'green', 'yellow', 'black']

if not UV:
print("MAKE UW PLOT")

ant1, ant2 = baseline.split('-')

for n, t_input_name in enumerate(t_input_names):

try:
m = open(mappingfiles[n])
mapping = {int(i): int(j) for i, j in json.load(m).items()}
ref_stats, ref_ids = get_station_id(t_final_name)
new_stats, new_ids = get_station_id(t_input_name)

with table(t_final_name, ack=False) as f:
uvw1 = f.getcol("UVW")[list(mapping.values())]
id_map = dict(zip([ref_stats.index(a) for a in new_stats], new_ids))
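# Illustration of the id remapping above (hypothetical stations): if the final
# MS lists ref_stats = ['CS001', 'CS002', 'RS106'] and this input MS has
# new_stats = ['CS002', 'RS106'] with new_ids = [0, 1], then id_map = {1: 0, 2: 1},
# i.e. reference antenna 1 corresponds to antenna 0 in this input MS.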

with table(t_input_name, ack=False) as f:
uvw2 = f.getcol("UVW")[list(mapping.keys())]
with table(t_final_name, ack=False) as f:
fsub = f.query(f'ANTENNA1={ant1} AND ANTENNA2={ant2} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW')
uvw1 = fsub.getcol("UVW")

# Scatter plot for uvw1
if n == 0:
lbl = 'Final MS'
else:
lbl = None
with table(t_input_name, ack=False) as f:
fsub = f.query(f'ANTENNA1={id_map[int(ant1)]} AND ANTENNA2={id_map[int(ant2)]} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW')
uvw2 = fsub.getcol("UVW")

plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.4, s=100, marker='o')
# plt.plot(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], color='blue', alpha=0.4)
# Scatter plot for uvw1
if n == 0:
lbl = 'Final MS'
else:
lbl = None

# UVW columns are (u, v, w) = indices (0, 1, 2)
plt.scatter(uvw1[:, 0], uvw1[:, 1] if UV else uvw1[:, 2], label=lbl, color='blue', edgecolor='black', alpha=0.2, s=100, marker='o')
# plt.plot(uvw1[:, 0], uvw1[:, 1] if UV else uvw1[:, 2], color='blue', alpha=0.4)

# Scatter plot for uvw2
plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'MS {n} to stack', color=colors[n], edgecolor='black', alpha=0.7, s = 40, marker='*')
# Scatter plot for uvw2
plt.scatter(uvw2[:, 0], uvw2[:, 1] if UV else uvw2[:, 2], label=f'Dataset {n}', color=colors[n], edgecolor='black', alpha=0.7, s=40, marker='*')

except:
pass

# Adding labels and title
plt.xlabel("U (m)", fontsize=14)
@@ -1415,6 +1436,8 @@ def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, ma
# Adding legend
plt.legend(fontsize=12)

plt.tight_layout()

if saveas is None:
plt.show()
else:
@@ -1435,24 +1458,22 @@ def make_baseline_uvw_plots(tabl, mslist):

for baseline in baselines:
bl = '-'.join([str(a) for a in baseline])
plot_baseline_track(tabl, sorted(mslist), sorted([ms + '_baseline_mapping/'+ bl+'.json' for ms in mslist]),
saveas=f'baseline_plots/{bl}.png')

return
plot_baseline_track(tabl, sorted(mslist), bl, saveas=f'baseline_plots/{bl}.png')


def parse_args():
"""
Parse input arguments
"""

parser = ArgumentParser(description='MS stacking')
parser.add_argument('msin', nargs='+', help='Measurement sets to stack')
parser = ArgumentParser(description='Sidereal visibility averaging')
parser.add_argument('msin', nargs='+', help='Measurement sets to combine')
parser.add_argument('--msout', type=str, default='empty.ms', help='Measurement set output name')
parser.add_argument('--chunk_mem', type=float, default=4., help='Chunk memory size. Large files need larger parameter, small files can have small parameter value.')
parser.add_argument('--avg', type=float, default=1., help='Additional final frequency and time averaging')
parser.add_argument('--less_avg', type=float, default=1., help='Factor to reduce averaging (only in combination with --advanced_stacking). Helps to speed up stacking, but gives less accurate results.')
parser.add_argument('--advanced_stacking', action='store_true', help='Increase time resolution during stacking (resulting in larger data volume).')
parser.add_argument('--time_res', type=float, help='Desired time resolution in seconds')
parser.add_argument('--keep_mapping', action='store_true', help='Do not remove mapping files')
parser.add_argument('--record_time', action='store_true', help='Record wall-time of stacking')
parser.add_argument('--no_compression', action='store_true', help='No compression of data')
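# Example invocation (file names illustrative):
#   python ms_merger.py --msout combined.ms --time_res 4 night*.ms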
@@ -1474,18 +1495,25 @@ def ms_merger():
if check_folder_exists(args.msout):
sys.exit(f"ERROR: {args.msout} already exists! Delete file first if you want to overwrite.")

# Find averaging_factor
if not args.advanced_stacking:
if args.less_avg != 1.:
sys.exit("ERROR: --less_avg only used in combination with --advanced_stacking")
if args.advanced_stacking and args.time_res is not None:
sys.exit("ERROR: --advanced_stacking and --time_res cannot be both given.")

if args.time_res is not None:
avg = 1
print(f"Use time resolution {args.time_res} seconds")
else:
avg = get_avg_factor(args.msin, args.less_avg)
print(f"Intermediate averaging factor {avg}\n"
f"Final averaging factor {max(int(avg * args.avg), 1) if args.advanced_stacking else int(args.avg)}")
# Find averaging_factor
if not args.advanced_stacking:
if args.less_avg != 1.:
sys.exit("ERROR: --less_avg only used in combination with --advanced_stacking")
avg = 1
else:
avg = get_avg_factor(args.msin, args.less_avg)
print(f"Intermediate averaging factor {avg}\n"
f"Final averaging factor {max(int(avg * args.avg), 1) if args.advanced_stacking else int(args.avg)}")

t = Template(args.msin, args.msout)
t.make_template(overwrite=True, avg_factor=avg)
t.make_template(overwrite=True, time_res=args.time_res, avg_factor=avg)
t.make_uvw()
print("\n############\nTemplate creation completed\n############")

@@ -1502,7 +1530,7 @@ def ms_merger():

# Apply dysco compression
if not args.no_compression:
compress(args.msout) #, max(int(avg * args.avg), 1) if args.advanced_stacking else int(args.avg))
compress(args.msout)

if args.plot_uv_baseline_coverage:
make_baseline_uvw_plots(args.msout, args.msin)
