Skip to content

Commit

Permalink
h5 merger rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
jurjen93 committed Jul 10, 2024
1 parent e804c52 commit daba1fe
Show file tree
Hide file tree
Showing 5 changed files with 484 additions and 171 deletions.
107 changes: 107 additions & 0 deletions fits_helpers/plot_baseline_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from casacore.tables import table
import sys
import matplotlib.pyplot as plt
from glob import glob
import os

# Set the MPLCONFIGDIR environment variable
os.system('mkdir -p ~/matplotlib_cache')
os.environ['MPLCONFIGDIR'] = os.path.expanduser('~/matplotlib_cache')


def get_station_id(ms):
"""
Get station with corresponding id number
:param:
- ms: measurement set
:return:
- antenna names, IDs
"""

t = table(ms+'::ANTENNA', ack=False)
ants = t.getcol("NAME")
t.close()

t = table(ms+'::FEED', ack=False)
ids = t.getcol("ANTENNA_ID")
t.close()

return ants, ids


def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, baseline='0-1', UV=True, saveas=None):
"""
Plot baseline track
:param:
- t_final_name: table with final name
- t_input_names: tables to compare with
- mappingfiles: baseline mapping files
"""

if len(t_input_names) > 4:
sys.exit("ERROR: Can just plot 4 inputs")

colors = ['red', 'green', 'yellow', 'black']

if not UV:
print("MAKE UW PLOT")

ant1, ant2 = baseline.split('-')

for n, t_input_name in enumerate(t_input_names):

ref_stats, ref_ids = get_station_id(t_final_name)
new_stats, new_ids = get_station_id(t_input_name)

id_map = dict(zip([ref_stats.index(a) for a in new_stats], new_ids))

with table(t_final_name, ack=False) as f:
fsub = f.query(f'ANTENNA1={ant1} AND ANTENNA2={ant2} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW')
uvw1 = fsub.getcol("UVW")

with table(t_input_name, ack=False) as f:
fsub = f.query(f'ANTENNA1={id_map[int(ant1)]} AND ANTENNA2={id_map[int(ant2)]} AND NOT ALL(WEIGHT_SPECTRUM == 0)', columns='UVW')
uvw2 = fsub.getcol("UVW")

# Scatter plot for uvw1
if n == 0:
lbl = 'Final MS'
else:
lbl = None

plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.2, s=100, marker='o')

# Scatter plot for uvw2
plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'Dataset {n}', color=colors[n], edgecolor='black', alpha=0.7, s=40, marker='*')


# Adding labels and title
plt.xlabel("U (m)", fontsize=14)
plt.ylabel("V (m)" if UV else "W (m)", fontsize=14)

# Adding grid
plt.grid(True, linestyle='--', alpha=0.6)

# Adding legend
plt.legend(fontsize=12)

plt.xlim(561260, 563520)
plt.ylim(192782, 194622)

plt.tight_layout()

if saveas is None:
plt.show()
else:
plt.savefig(saveas, dpi=150)
plt.close()

plot_baseline_track('test_1.ms', sorted(glob('a*.ms')), '0-71', saveas='test_1.png')
plot_baseline_track('test_2.ms', sorted(glob('a*.ms')), '0-71', saveas='test_2.png')
plot_baseline_track('test_3.ms', sorted(glob('a*.ms')), '0-71', saveas='test_3.png')
plot_baseline_track('test_4.ms', sorted(glob('a*.ms')), '0-71', saveas='test_4.png')
plot_baseline_track('test_6.ms', sorted(glob('a*.ms')), '0-71', saveas='test_6.png')
plot_baseline_track('test_8.ms', sorted(glob('a*.ms')), '0-71', saveas='test_8.png')
106 changes: 99 additions & 7 deletions h5_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from losoto.h5parm import h5parm
from losoto.lib_operations import reorderAxes
from numpy import zeros, ones, round, unique, array_equal, append, where, isfinite, complex128, expand_dims, \
pi, array, all, exp, angle, sort, sum, finfo, take, diff, equal, take, transpose, cumsum, insert, abs, asarray, newaxis, argmin
pi, array, all, exp, angle, sort, sum, finfo, take, diff, equal, take, transpose, cumsum, insert, abs, asarray, newaxis, argmin, cos, sin
import os
import re
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -271,6 +271,20 @@ def has_integer(input):
except: # dangerous but ok for now ;-)
return False


def has_rotation(h5_table):
"""
Verify if h5parm has a rotation table
"""

with tables.open_file(h5_table) as H:
for solset in H.root._v_groups.keys():
for soltab in H.root._f_get_child(solset)._v_groups.keys():
if 'rotation' in soltab:
return True
return False


def coordinate_distance(c1, c2):
"""
Find distance between sources
Expand All @@ -290,6 +304,8 @@ def coordinate_distance(c1, c2):
c2 = SkyCoord(c2[0], c2[1], unit='degree', frame='icrs')
return c1.separation(c2).to(u.degree).value



class MergeH5:
"""Merge multiple h5 tables"""

Expand Down Expand Up @@ -1368,7 +1384,7 @@ def _DP3_order(self, soltab):
else:
DP3_axes = []

if 'phase' in soltab or ('tec' in soltab and self.convert_tec) or 'rotation' in soltab:
if 'phase' in soltab or ('tec' in soltab and self.convert_tec):
self.phases = reorderAxes(self.phases, self.axes_final, DP3_axes)
elif 'amplitude' in soltab:
self.amplitudes = reorderAxes(self.amplitudes, self.axes_final, DP3_axes)
Expand Down Expand Up @@ -1921,7 +1937,7 @@ def add_weights(self):
weight = reorderAxes(weight, axes, [a for a in axes_new if a in axes])

# important to do the following in case the input tables are not all different directions
m = min(weight_out.shape[axes.index('dir')] - 1, m)
m = min(weight_out.shape[axes_new.index('dir')] - 1, m)
if self.merge_all_in_one:
m = 0

Expand All @@ -1932,12 +1948,10 @@ def add_weights(self):
if weight.shape[-2] != 1 and len(weight.shape)==5:
print("Merge multi-dir weights")
if weight.shape[-2] != weight_out.shape[-2]:
print(weight.shape, weight_out.shape)
sys.exit("ERROR: multi-dirs do not have equal shape.")
if 'pol' not in axes and 'pol' in axes_new:
newvals = expand_dims(newvals, axis=axes_new.index('pol'))
for n in range(weight_out.shape[-1]):
print(weight_out.shape, newvals.shape)
weight_out[..., n] *= newvals[..., -1]

elif weight.ndim != weight_out.ndim: # not the same value shape
Expand Down Expand Up @@ -2495,6 +2509,72 @@ def add_antenna_source_tables(self):
return


def split_rotation(h5_in):
"""
Split rotation from input h5 and return fulljones output matrix h5
See Eq. B1 in https://www.aanda.org/articles/aa/pdf/2019/02/aa33867-18.pdf
"""

os.system(f'cp {h5_in} {h5_in}.rotation.h5')
os.system(f'cp {h5_in} {h5_in}.phase_amp.h5')

with tables.open_file(h5_in + '.phase_amp.h5', 'r+') as H:
H.remove_node("/sol000", "rotation000", recursive=True)

with tables.open_file(h5_in + '.rotation.h5', 'r+') as H:
axes = make_utf8(H.root.sol000.rotation000.val.attrs["AXES"]) + ',pol'
rot = H.root.sol000.rotation000.val[:]
Axx = cos(rot)
Axy = -sin(rot)
Ayx = sin(rot)
Ayy = cos(rot)

phasesxx = angle(Axx)
amplitudesxx = abs(Axx)
phasesxy = angle(Axy)
amplitudesxy = abs(Axy)
phasesyx = angle(Ayx)
amplitudesyx = abs(Ayx)
phasesyy = angle(Ayy)
amplitudesyy = abs(Ayy)

new_shape = list(rot.shape) + [4]
newphase = zeros(new_shape)
newamps = zeros(new_shape)

newphase[..., 0] = phasesxx
newphase[..., 1] = phasesxy
newphase[..., 2] = phasesyx
newphase[..., 3] = phasesyy

newamps[..., 0] = amplitudesxx
newamps[..., 1] = amplitudesxy
newamps[..., 2] = amplitudesyx
newamps[..., 3] = amplitudesyy

H.remove_node("/sol000", "rotation000", recursive=True)
H.remove_node("/sol000/phase000", "val", recursive=True)
H.remove_node("/sol000/amplitude000", "val", recursive=True)
H.remove_node("/sol000/phase000", "weight", recursive=True)
H.remove_node("/sol000/amplitude000", "weight", recursive=True)
H.remove_node("/sol000/phase000", "pol", recursive=True)
H.remove_node("/sol000/amplitude000", "pol", recursive=True)

H.create_array('/sol000/phase000', 'weight', ones(newphase.shape))
H.create_array('/sol000/amplitude000', 'weight', ones(newamps.shape))
H.create_array('/sol000/phase000', 'val', newphase)
H.create_array('/sol000/amplitude000', 'val', newamps)
H.create_array('/sol000/phase000', 'pol', array([b'XX', b'XY', b'YX', b'YY']))
H.create_array('/sol000/amplitude000', 'pol', array([b'XX', b'XY', b'YX', b'YY']))

H.root.sol000.phase000.val.attrs['AXES'] = bytes(axes, 'utf-8')
H.root.sol000.phase000.weight.attrs['AXES'] = bytes(axes, 'utf-8')
H.root.sol000.amplitude000.val.attrs['AXES'] = bytes(axes, 'utf-8')
H.root.sol000.amplitude000.weight.attrs['AXES'] = bytes(axes, 'utf-8')

return h5_in+'.rotation.h5', h5_in + '.phase_amp.h5'


def h5_check(h5):
"""
With this function you can print a summary from the h5 solution file
Expand Down Expand Up @@ -2725,6 +2805,7 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv
if type(h5_tables) == str:
h5_tables = glob(h5_tables)


if h5_out is None:
for h5check in h5_tables:
h5_check(h5check)
Expand All @@ -2740,7 +2821,7 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv
elif h5_out.split('/')[-1] in [f.split('/')[-1] for f in glob(h5_out)]:
os.system('rm {}'.format(h5_out))

# If alternative solset number is given, we will make a temp h5 file that has the alternative solset number because the code runs on sol000 (will be cleaned up in the end)
# If alternative solset number is given, we will make a temp h5 file that has the alternative solset number because the code runs on sol000
if use_solset != 'sol000':
for h5_ind, h5 in enumerate(h5_tables):
temph5 = h5.replace('.h5', '_temph5merger.h5')
Expand All @@ -2749,6 +2830,14 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv
_change_solset(temph5, use_solset, 'sol000')
h5_tables[h5_ind] = temph5

# If rotation table is given, split this from the list of h5s and add tmp file
for n, h5_in in enumerate(h5_tables):
if has_rotation(h5_in):
print("Rotation table in "+h5_in+" splitting table into fulljones for correct merging")
rot_h5, phaseamp_h5 = split_rotation(h5_in)
h5_tables[n] = phaseamp_h5
h5_tables.insert(n+1, rot_h5)

#################################################
#################### MERGING ####################
#################################################
Expand Down Expand Up @@ -2789,14 +2878,17 @@ def merge_h5(h5_out=None, h5_tables=None, ms_files=None, h5_time_freq=None, conv
for st_group in merge.all_soltabs:
if len(st_group) > 0:
for st in st_group:
if 'rotation' in st:
continue
merge.get_model_h5('sol000', st)
merge.merge_tables('sol000', st, min_distance)
if not merge.doublefulljones:
if merge.convert_tec and (('phase' in st_group[0]) or ('tec' in st_group[0])):
# make sure tec is merged in phase only (if convert_tec==True)
merge.create_new_dataset('sol000', 'phase')
else:
merge.create_new_dataset('sol000', st)
if not 'rotation' in st:
merge.create_new_dataset('sol000', st)
if merge.doublefulljones:
merge.matrix_multiplication()
merge.create_new_dataset('sol000', 'phase')
Expand Down
2 changes: 1 addition & 1 deletion ms_helpers/get_longest_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def parse_args():
"""

parser = ArgumentParser(description='Longest baseline length')
parser.add_argument('--ms', help='MS', required=True)
parser.add_argument('ms', help='MS', required=True)
return parser.parse_args()


Expand Down
Loading

0 comments on commit daba1fe

Please sign in to comment.