From bab02be2cfa9b4a0d97f3ec5020908c34f2b1453 Mon Sep 17 00:00:00 2001 From: jurjen93 Date: Thu, 8 Aug 2024 14:02:20 +0200 Subject: [PATCH] tec weight prop and float weights --- h5_merger.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/h5_merger.py b/h5_merger.py index e8cd755a..d24fce27 100644 --- a/h5_merger.py +++ b/h5_merger.py @@ -18,7 +18,7 @@ from losoto.h5parm import h5parm from losoto.lib_operations import reorderAxes from numpy import zeros, ones, round, unique, array_equal, append, where, isfinite, complex128, expand_dims, \ - pi, array, all, exp, angle, sort, sum, finfo, take, diff, equal, take, transpose, cumsum, insert, abs, asarray, newaxis, argmin, cos, sin + pi, array, all, exp, angle, sort, sum, finfo, take, diff, equal, take, transpose, cumsum, insert, abs, asarray, newaxis, argmin, cos, sin, float32 import os import re from scipy.interpolate import interp1d @@ -1918,15 +1918,18 @@ def add_weights(self): H = tables.open_file(self.h5name_out, 'r+') for solset in H.root._v_groups.keys(): ss = H.root._f_get_child(solset) - for n, soltab in enumerate(ss._v_groups.keys()): + print(ss._v_groups.keys()) + for n, soltab in enumerate(set(list(ss._v_groups.keys())+['tec000'])): print(soltab + ', from:') - st = ss._f_get_child(soltab) + if 'tec' in soltab and \ + soltab not in list(ss._v_groups.keys()): + st = ss._f_get_child(soltab.replace('tec', 'phase')) + else: + st = ss._f_get_child(soltab) shape = st.val.shape weight_out = ones(shape) axes_new = make_utf8(st.val.attrs["AXES"]).split(',') for m, input_h5 in enumerate(self.h5_tables): - - print(input_h5) T = tables.open_file(input_h5) if soltab not in list(T.root._f_get_child(solset)._v_groups.keys()): T.close() @@ -1941,11 +1944,10 @@ def add_weights(self): if self.merge_all_in_one: m = 0 - newvals = self._interp_along_axis(weight, st2.time[:], st.time[:], axes_new.index('time'), fill_value=1.).astype(int) - newvals = self._interp_along_axis(newvals, st2.freq[:], st.freq[:], axes_new.index('freq'), fill_value=1.).astype(int) - + newvals = self._interp_along_axis(weight, st2.time[:], st.time[:], axes_new.index('time'), fill_value=1.).astype(float32) + newvals = self._interp_along_axis(newvals, st2.freq[:], st.freq[:], axes_new.index('freq'), fill_value=1.).astype(float32) - if weight.shape[-2] != 1 and len(weight.shape)==5: + if weight.shape[-2] != 1 and len(weight.shape) == 5: print("Merge multi-dir weights") if weight.shape[-2] != weight_out.shape[-2]: sys.exit("ERROR: multi-dirs do not have equal shape.") @@ -2003,6 +2005,8 @@ def add_weights(self): else: weight_out *= newvals + # print(weight_out[:, -8, ...]) + else: sys.exit('ERROR: Upsampling of weights bug due to unexpected missing axes.\n axes from ' + input_h5 + ': ' + str(axes) + '\n axes from '