From 54db93a720978c8f969b1652122c299ce8c3382f Mon Sep 17 00:00:00 2001 From: Garrett 'Karto' Keating Date: Fri, 13 Oct 2023 22:05:15 -0400 Subject: [PATCH] Adding handling for spectral weights in MirParser --- pyuvdata/uvdata/mir_parser.py | 279 ++++++++++++++++------- pyuvdata/uvdata/tests/test_mir_parser.py | 25 +- 2 files changed, 225 insertions(+), 79 deletions(-) diff --git a/pyuvdata/uvdata/mir_parser.py b/pyuvdata/uvdata/mir_parser.py index 423204934..ec3bdff37 100644 --- a/pyuvdata/uvdata/mir_parser.py +++ b/pyuvdata/uvdata/mir_parser.py @@ -764,6 +764,7 @@ def _convert_raw_to_vis(raw_dict): dtype=np.complex64 ), "flags": sp_raw["data"][::2] == -32768, + "weights": np.ones(len(sp_raw["data"]) >> 1, dtype=np.float32), } for sphid, sp_raw in raw_dict.items() } @@ -968,9 +969,13 @@ def _read_data( } else: data_arr = packdata[start_idx:end_idx] - temp_dict[hid] = {"data": data_arr, "flags": np.isnan(data_arr)} + temp_dict[hid] = { + "data": data_arr, + "flags": np.isnan(data_arr), + "weights": np.ones_like(data_arr), + } - if apply_cal: + if apply_cal and is_cross: temp_dict = self._convert_raw_to_vis(temp_dict) temp_dict = self._apply_compass_solns(self._compass_solns, temp_dict) @@ -1115,7 +1120,7 @@ def _write_auto_data(self, filepath, append_data=False, raise_err=True): ) packdata[inhid].tofile(file) - def apply_tsys(self, invert=False, force=False): + def apply_tsys(self, invert=False, force=False, use_cont_det=True): """ Apply Tsys calibration to the visibilities. @@ -1150,64 +1155,99 @@ def apply_tsys(self, invert=False, force=False): "invert=True to apply the correction first." ) - # Create a dictionary here to map antenna pair + integration time step with - # a sqrt(tsys) value. Note that the last index here is the receiver number, - # which technically has a different keyword under which the system temperatures - # are stored. - tsys_dict = { - (idx, jdx, 0): tsys**0.5 if (tsys > 0 and tsys < 1e5) else 0.0 - for idx, jdx, tsys in zip( - self.eng_data["inhid"], self.eng_data["antenna"], self.eng_data["tsys"] - ) - } - tsys_dict.update( - { - (idx, jdx, 1): tsys**0.5 if (tsys > 0 and tsys < 1e5) else 0.0 - for idx, jdx, tsys in zip( - self.eng_data["inhid"], - self.eng_data["antenna"], - self.eng_data["tsys_rx2"], - ) + if use_cont_det: + # Create a dictionary here to map antenna pair + integration time step with + # a sqrt(tsys) value. Note that the last index here is the receiver number, + # which technically has a different keyword under which the system + # temperatures are stored. + tsys_dict = { + (idx, jdx, 0): tsys**0.5 if (tsys > 0 and tsys < 1e5) else 0.0 + for idx, jdx, tsys in zip(*self.eng_data[("inhid", "antenna", "tsys")]) } - ) + tsys_dict.update( + { + (idx, jdx, 1): tsys**0.5 if (tsys > 0 and tsys < 1e5) else 0.0 + for idx, jdx, tsys in zip( + *self.eng_data[("inhid", "antenna", "tsys_rx2")] + ) + } + ) - # now create a per-blhid SEFD dictionary based on antenna pair, integration - # time step, and receiver pairing. - normal_dict = {} - for blhid, idx, jdx, kdx, ldx, mdx in zip( - self.bl_data["blhid"], - self.bl_data["inhid"], - self.bl_data["iant1"], - self.bl_data["ant1rx"], - self.bl_data["iant2"], - self.bl_data["ant2rx"], - ): - try: - normal_dict[blhid] = (2.0 * self.jypk) * ( - tsys_dict[(idx, jdx, kdx)] * tsys_dict[(idx, ldx, mdx)] - ) - except KeyError: - warnings.warn( - "No tsys for blhid %i found (%i-%i baseline, inhid %i). " - "Baseline record will be flagged." % (blhid, jdx, ldx, idx) - ) + # now create a per-blhid SEFD dictionary based on antenna pair, integration + # time step, and receiver pairing. + normal_dict = {} + for blhid, idx, jdx, kdx, ldx, mdx in zip( + *self.bl_data[("blhid", "inhid", "iant1", "ant1rx", "iant2", "ant2rx")] + ): + try: + normal_dict[blhid] = (2.0 * self.jypk) * ( + tsys_dict[(idx, jdx, kdx)] * tsys_dict[(idx, ldx, mdx)] + ) + except KeyError: + warnings.warn( + "No tsys for blhid %i found (%i-%i baseline, inhid %i). " + "Baseline record will be flagged." % (blhid, jdx, ldx, idx) + ) + + if invert: + for key, value in normal_dict.items(): + if value != 0: + normal_dict[key] = 1.0 / value + + # Finally, multiply the individual spectral records by the SEFD values + # that are in the dictionary. + int_time_dict = dict(self.in_data.get_value(("inhid", "rinteg"))) + for sp_rec in self.sp_data: + vis_dict = self.vis_data[sp_rec["sphid"]] + n_sample = abs(sp_rec["fres"] * 1e6) * int_time_dict[sp_rec["inhid"]] + try: + norm_val = normal_dict[sp_rec["blhid"]] + if norm_val == 0.0: + vis_dict["flags"][:] = True + else: + vis_dict["data"] *= norm_val + if invert: + vis_dict["weights"] *= (norm_val**2.0) / n_sample + else: + vis_dict["weights"] *= n_sample / (norm_val**2.0) + except KeyError: + self.vis_data[sp_rec["sphid"]]["flags"][:] = True + else: + # The "wt" column is calculated as (integ time)/(T_DSB ** 2), but we want + # units of Jy**-2. To do this, we just need to multiply by one of the + # forward gain of the antenna (130 Jy/K for SMA) squared and the channel + # width. The factor of 2**2 (4) arises because we need to convert T_DSB**2 + # to T_SSB**2. Note the 1e6 is there to convert fres from MHz to Hz. + wt_arr = ( + self.sp_data["wt"] + * abs(self.sp_data["fres"]) + * (1e6 * ((self.jypk * 2.0) ** (-2.0))) + ) - if invert: - for key, value in normal_dict.items(): - if value != 0: - normal_dict[key] = 1.0 / value + # For data normalization, we used the "wt" but strip out the integration + # time and take the inverse sqrt to get T_DSB, and then use the forward + # gain (plus 2x for DSB -> SSB) to get values of Jy. + norm_arr = np.zeros_like(wt_arr) + norm_arr = np.reciprocal( + self.sp_data["wt"], where=(wt_arr != 0), out=norm_arr + ) + norm_arr = ( + np.sqrt(norm_arr) + * (self.jypk * 2.0) + / self.in_data.get_value("rinteg", header_key=self.sp_data["inhid"]) + ) - # Finally, multiply the individual spectral records by the SEFD values - # that are in the dictionary. - for sphid, blhid in zip(self.sp_data["sphid"], self.sp_data["blhid"]): - try: - norm_val = normal_dict[blhid] + if invert: + for arr in [norm_arr, wt_arr]: + arr = np.reciprocal(arr, where=(arr != 0), out=arr) + + for sphid, norm_val, wt_val in zip(self.sp_data["sphid"], wt_arr, norm_arr): + vis_dict = self.vis_data[sphid] if norm_val == 0.0: - self.vis_data[sphid]["flags"][:] = True + vis_dict["flags"][:] = True else: - self.vis_data[sphid]["data"] *= norm_val - except KeyError: - self.vis_data[sphid]["flags"][:] = True + vis_dict["data"] *= norm_val + vis_dict["weights"] *= wt_val self._tsys_applied = not invert @@ -1517,6 +1557,7 @@ def unload_data(self, unload_vis=True, unload_raw=True, unload_auto=True): for item in self.vis_data.values(): del item["data"] del item["flags"] + del item["weights"] self.vis_data = None self._tsys_applied = False if unload_raw and self.raw_data is not None: @@ -1528,6 +1569,7 @@ def unload_data(self, unload_vis=True, unload_raw=True, unload_auto=True): for item in self.auto_data.values(): del item["data"] del item["flags"] + del item["weights"] self.auto_data = None def _update_filter(self, update_data=None): @@ -2028,7 +2070,7 @@ def write( self._write_auto_data(filepath, append_data=append_data, raise_err=False) @staticmethod - def _rechunk_data(data_dict, chan_avg_arr, inplace=False): + def _rechunk_data(data_dict, chan_avg_arr, inplace=False, rechunk_weights=False): """ Rechunk regular cross- and auto-correlation spectra. @@ -2051,6 +2093,11 @@ def _rechunk_data(data_dict, chan_avg_arr, inplace=False): If True, entries in `vis_dict` will be updated with spectrally averaged data. If False (default), then the method will construct a new dict that will contain the spectrally averaged data. + rechunk_weights : bool + If True, will normalize the rechunked weights to account for the increased + bandwidth of the individual channels -- needed for accurately calculating + the "absolute" weights. Default is False, as most calls to this function + are done before absolute calibration is applied (e.g., via `apply_tsys`). Returns ------- @@ -2063,17 +2110,17 @@ def _rechunk_data(data_dict, chan_avg_arr, inplace=False): new_data_dict = data_dict if inplace else {} - for chan_avg, (hkey, sp_data) in zip(chan_avg_arr, data_dict.items()): + for chan_avg, (hkey, vis_data) in zip(chan_avg_arr, data_dict.items()): # If there isn't anything to average, we can skip the heavy lifting # and just proceed on to the next record. if chan_avg == 1: if not inplace: - new_data_dict[hkey] = copy.deepcopy(sp_data) + new_data_dict[hkey] = copy.deepcopy(vis_data) continue # Otherwise, we need to first get a handle on which data is "good" # for spectrally averaging over. - good_mask = ~sp_data["flags"].reshape((-1, chan_avg)) + good_mask = ~vis_data["flags"].reshape((-1, chan_avg)) # We need to count the number of valid visibilities that goes into each # new channel, so that we can normalize appropriately later. Note we cast @@ -2091,14 +2138,35 @@ def _rechunk_data(data_dict, chan_avg_arr, inplace=False): # Now take the sum of all valid visibilities, multiplied by the # normalization factor. - temp_vis = ( - sp_data["data"].reshape((-1, chan_avg)).sum(where=good_mask, axis=-1) - * temp_count + temp_vis = temp_count * ( + vis_data["data"].reshape((-1, chan_avg)).sum(where=good_mask, axis=-1) + ) + + # Assuming no weighting applied, we need to calculate the sum of the + # variances for the individual channels to get a per (rechunked) channel + # variance, from which we can tabulate weights. + temp_weights = temp_count * np.reciprocal( + vis_data["weights"].reshape((-1, chan_avg)), where=good_mask + ).sum(where=good_mask, axis=-1) + + # If weighting has already been applied (i.e., not just "nsamples"), then + # we need to do a bit extra accounting here to track the fact that we've + # now upped the bandwidth in this particular channel. + if rechunk_weights: + temp_weights *= temp_count + + # Get the weights back into Jy**-2 units. + temp_weights = np.reciprocal( + temp_weights, where=(temp_weights != 0), out=temp_weights ) # Finally, plug the spectrally averaged data back into the dict, flagging # channels with no valid data. - new_data_dict[hkey] = {"data": temp_vis, "flags": temp_count == 0} + new_data_dict[hkey] = { + "data": temp_vis, + "flags": temp_count == 0, + "weights": temp_weights, + } return new_data_dict @@ -2258,7 +2326,12 @@ def rechunk(self, chan_avg): chan_avg_arr = [chanavg_dict[band] for band in getattr(self, attr)["iband"]] if attr == "sp_data": - self._rechunk_data(self.vis_data, chan_avg_arr, inplace=True) + self._rechunk_data( + self.vis_data, + chan_avg_arr, + inplace=True, + rechunk_weights=self._tsys_applied, + ) self._rechunk_raw(self.raw_data, chan_avg_arr, inplace=True) else: self._rechunk_data(self.auto_data, chan_avg_arr, inplace=True) @@ -2834,7 +2907,7 @@ def _read_compass_solns(self, filename): ).astype( np.complex64 ) # BP gains (3D array) - sefd_arr = np.array(file["sefdArr"]) + sefd_arr = np.array(file["sefdArr"]) ** 2.0 # Parse out the bandpass solutions for each antenna, pol/receiver, and # sideband-chunk combination. @@ -2865,18 +2938,43 @@ def _read_compass_solns(self, filename): if key1[idx_compare:] != key2[idx_compare:]: continue - cal_soln = np.reciprocal( - dict1["cal_data"] * np.conj(dict2["cal_data"]) + # Put together bandpass gains for the visibilities + cal_soln = np.zeros( + len(dict1["cal_data"]), + dtype=np.float32 if key1 == key2 else np.complex64, ) + cal_flags = dict1["cal_flags"] | dict2["cal_flags"] + + # Split the processing here based on autos vs crosses if key1 == key2: - cal_soln = np.abs(cal_soln) + cal_soln = np.reciprocal( + abs(dict1["cal_data"] * dict2["cal_data"]), + where=~cal_flags, + out=cal_soln, + ) + else: + cal_soln = np.reciprocal( + dict1["cal_data"] * np.conj(dict2["cal_data"]), + where=~cal_flags, + out=cal_soln, + ) + + # Now generate re-weighting solns based on per-chanel SEFD + # measurements calculated by COMPASS. + weight_soln = np.zeros_like(dict1["sefd_data"]) + weight_flags = dict1["sefd_flags"] | dict2["sefd_flags"] + weight_soln = np.reciprocal( + dict1["sefd_data"] * dict2["sefd_data"], + where=~weight_flags, + out=weight_soln, + ) new_key = key1[:2] + key2 bp_gains_corr[new_key] = { - "cal_data": cal_soln, - "cal_flags": dict1["cal_flags"] | dict2["cal_flags"], - "sefd_data": dict1["sefd_data"] * dict2["sefd_data"], - "sefd_flags": dict1["sefd_flags"] * dict2["sefd_flags"], + "cal_soln": cal_soln, + "cal_flags": cal_flags, + "weight_soln": weight_soln, + "weight_flags": weight_flags, } compass_soln_dict["bp_gains_corr"] = bp_gains_corr @@ -3049,12 +3147,20 @@ def _apply_compass_solns( cal_soln = bp_soln[(ant1, rx1, ant2, rx2, sb, chunk)] except KeyError: # Flag the soln if either ant1 or ant2 solns are bad. - cal_soln = {"cal_data": 1.0, "cal_flags": True} + cal_soln = { + "cal_soln": 1.0, + "cal_flags": True, + "weight_soln": 0.0, + "weight_flags": True, + } finally: # One way or another, we should have a set of gains solutions that # we can apply now (flagging the data where appropriate). - vis_data[sphid]["data"] *= cal_soln["cal_data"] - vis_data[sphid]["flags"] |= cal_soln["cal_flags"] + vis_data[sphid]["data"] *= cal_soln["cal_soln"] + vis_data[sphid]["weights"] *= cal_soln["weight_soln"] + vis_data[sphid]["flags"] |= ( + cal_soln["cal_flags"] | cal_soln["weight_flags"] + ) if apply_flags: # For the sake of reading/coding, let's assign the two catalogs of flags @@ -3268,6 +3374,7 @@ def _chanshift_vis(vis_dict, shift_tuple_list, flag_adj=True, inplace=False): continue new_vis = np.empty_like(sp_vis["data"]) + new_weights = np.empty_like(sp_vis["weights"]) if shift_kernel is None: # If the shift kernel is None, it means that we only have a coarse @@ -3280,13 +3387,17 @@ def _chanshift_vis(vis_dict, shift_tuple_list, flag_adj=True, inplace=False): if coarse_shift < 0: new_vis[:coarse_shift] = sp_vis["data"][-coarse_shift:] new_flags[:coarse_shift] = sp_vis["flags"][-coarse_shift:] + new_weights[:coarse_shift] = sp_vis["weights"][-coarse_shift:] new_vis[coarse_shift:] = 0.0 new_flags[coarse_shift:] = True + new_weights[coarse_shift:] = 0.0 else: new_vis[coarse_shift:] = sp_vis["data"][:-coarse_shift] new_flags[coarse_shift:] = sp_vis["flags"][:-coarse_shift] + new_weights[coarse_shift:] = sp_vis["weights"][:-coarse_shift] new_vis[:coarse_shift] = 0.0 new_flags[:coarse_shift] = True + new_weights[:coarse_shift] = 0.0 else: # If we have to execute a convolution, then the indexing is a bit more # complicated. We use the "valid" option for convolve below, which will @@ -3316,6 +3427,10 @@ def _chanshift_vis(vis_dict, shift_tuple_list, flag_adj=True, inplace=False): temp_vis[sp_vis["flags"][l_clip:r_clip]] = ( np.complex64(np.nan) if flag_adj else np.complex64(0.0) ) + temp_weights = sp_vis["weights"][l_clip:r_clip].copy() + temp_weights[sp_vis["flags"][l_clip:r_clip]] = ( + np.float32(np.nan) if flag_adj else np.float32(0.0) + ) # For some reason, it's about 5x faster to split this up into real # and imaginary operations. The use of "valid" also speeds this up @@ -3326,21 +3441,31 @@ def _chanshift_vis(vis_dict, shift_tuple_list, flag_adj=True, inplace=False): new_vis.imag[l_edge:r_edge] = np.convolve( temp_vis.imag, shift_kernel, "valid" ) + new_weights[l_edge:r_edge] = np.convolve( + temp_weights, shift_kernel, "valid" + ) # Flag out the values beyond the outer bounds new_vis[:l_edge] = new_vis[r_edge:] = ( np.complex64(np.nan) if flag_adj else np.complex64(0.0) ) + new_weights[:l_edge] = new_weights[r_edge:] = ( + np.float32(np.nan) if flag_adj else np.float32(0.0) + ) # Finally, regenerate the flags array for the dict entry. if flag_adj: new_flags = np.isnan(new_vis) - new_vis[new_flags] = 0.0 + new_vis[new_flags] = new_weights[new_flags] = 0.0 else: new_flags = np.zeros_like(sp_vis["flags"]) new_flags[:l_edge] = new_flags[r_edge:] = True # Update our dict with the new values for this sphid - new_vis_dict[sphid] = {"data": new_vis, "flags": new_flags} + new_vis_dict[sphid] = { + "data": new_vis, + "flags": new_flags, + "weights": new_weights, + } return new_vis_dict diff --git a/pyuvdata/uvdata/tests/test_mir_parser.py b/pyuvdata/uvdata/tests/test_mir_parser.py index 39d9187e7..a512978fa 100644 --- a/pyuvdata/uvdata/tests/test_mir_parser.py +++ b/pyuvdata/uvdata/tests/test_mir_parser.py @@ -953,6 +953,7 @@ def test_rechunk_cross(inplace): 25624: { "data": (np.arange(1024) + np.flip(np.arange(1024) * 1j)), "flags": np.zeros(1024, dtype=bool), + "weights": np.ones(1024, dtype=np.float32), } } check_vals = np.arange(1024) + np.flip(np.arange(1024) * 1j) @@ -965,6 +966,7 @@ def test_rechunk_cross(inplace): assert vis_data.keys() == vis_copy.keys() assert np.all(vis_data[25624]["flags"] == np.zeros(1024, dtype=bool)) assert np.all(vis_data[25624]["data"] == check_vals) + assert np.all(vis_data[25624]["weights"] == np.ones(1024)) # Next, test averaging w/o flags vis_copy = MirParser._rechunk_data(vis_data, [4], inplace=inplace) @@ -974,6 +976,7 @@ def test_rechunk_cross(inplace): assert vis_data.keys() == vis_copy.keys() assert np.all(vis_copy[25624]["flags"] == np.zeros(256, dtype=bool)) assert np.all(vis_copy[25624]["data"] == check_vals) + assert np.all(vis_copy[25624]["weights"] == np.ones(256)) vis_data = vis_copy # Finally, check what happens if we flag data @@ -983,6 +986,7 @@ def test_rechunk_cross(inplace): assert vis_data.keys() == vis_copy.keys() assert np.all(vis_copy[25624]["flags"] == [False, True]) assert np.all(vis_copy[25624]["data"] == [check_vals[0], 0.0]) + assert np.all(vis_copy[25624]["weights"] == [1.0, 0.0]) @pytest.mark.parametrize("inplace", [True, False]) @@ -991,6 +995,7 @@ def test_rechunk_auto(inplace): 8675309: { "data": np.arange(-1024, 1024, dtype=np.float32), "flags": np.zeros(2048, dtype=bool), + "weights": np.ones(2048, dtype=np.float32), } } @@ -999,12 +1004,14 @@ def test_rechunk_auto(inplace): assert (auto_copy is auto_data) == inplace assert auto_data.keys() == auto_copy.keys() assert np.all(auto_copy[8675309]["data"] == np.arange(-1024, 1024)) + assert np.all(auto_copy[8675309]["weights"] == np.ones(2048)) # First up, test no averaging auto_copy = MirParser._rechunk_data(auto_data, [512], inplace=inplace) assert (auto_copy is auto_data) == inplace assert auto_data.keys() == auto_copy.keys() assert np.all(auto_copy[8675309]["data"] == [-768.5, -256.5, 255.5, 767.5]) + assert np.all(auto_copy[8675309]["weights"] == np.ones(4)) @pytest.mark.parametrize( @@ -1512,11 +1519,13 @@ def test_chanshift_vis(check_flags, flag_adj, fwd_dir, inplace): flag_vals = [False] * 4 flag_vals.append(check_flags) flag_vals.extend([False] * 3) - + weight_vals = np.ones(8) + weight_vals[flag_vals] = 0.0 vis_dict = { 456: { "data": np.array(vis_vals, dtype=np.complex64), "flags": np.array(flag_vals, dtype=bool), + "weights": np.array(weight_vals, dtype=np.float32), } } @@ -1529,7 +1538,8 @@ def test_chanshift_vis(check_flags, flag_adj, fwd_dir, inplace): assert new_dict is vis_dict assert np.all(vis_vals == new_dict[456]["data"]) - assert np.all(new_dict[456]["flags"] == flag_vals) + assert np.all(flag_vals == new_dict[456]["flags"]) + assert np.all(weight_vals == new_dict[456]["weights"]) # Now try a simple one-channel shift new_dict = MirParser._chanshift_vis( @@ -1550,9 +1560,14 @@ def test_chanshift_vis(check_flags, flag_adj, fwd_dir, inplace): flag_vals[good_slice] == np.roll(new_dict[456]["flags"], -1 if fwd_dir else 1)[good_slice] ) + assert np.all( + weight_vals[good_slice] + == np.roll(new_dict[456]["weights"], -1 if fwd_dir else 1)[good_slice] + ) assert np.all(new_dict[456]["data"][flag_slice] == 0.0) assert np.all(new_dict[456]["flags"][flag_slice]) + assert np.all(new_dict[456]["weights"][flag_slice] == 0.0) # Refresh the values, in case we are doing this in-place if inplace: @@ -1560,6 +1575,7 @@ def test_chanshift_vis(check_flags, flag_adj, fwd_dir, inplace): 456: { "data": np.array(vis_vals, dtype=np.complex64), "flags": np.array(flag_vals, dtype=bool), + "weights": np.array(weight_vals, dtype=np.float32), } } @@ -1575,18 +1591,23 @@ def test_chanshift_vis(check_flags, flag_adj, fwd_dir, inplace): exp_vals = np.roll(vis_vals, 2 if fwd_dir else -2) exp_flags = np.roll(flag_vals, 2 if fwd_dir else -2) + exp_weights = np.roll(weight_vals, 2 if fwd_dir else -2) exp_vals[None if fwd_dir else -2 : 2 if fwd_dir else None] = 0.0 exp_flags[None if fwd_dir else -2 : 2 if fwd_dir else None] = True + exp_weights[None if fwd_dir else -2 : 2 if fwd_dir else None] = 0.0 mod_slice = slice(4 - (-1 if fwd_dir else 2), 6 - (-1 if fwd_dir else 2)) if flag_adj: exp_flags[mod_slice] = check_flags exp_vals[mod_slice] = 0 if check_flags else [check_val * 0.75, check_val * 0.25] + exp_weights[mod_slice] = 0 if check_flags else 1 else: exp_vals[mod_slice] = [check_val * 0.25, check_val * 0.75] exp_flags[mod_slice] = False + exp_weights[mod_slice] = [0.25, 0.75] assert np.all(new_dict[456]["data"] == exp_vals) assert np.all(new_dict[456]["flags"] == exp_flags) + assert np.all(new_dict[456]["weights"] == exp_weights) @pytest.mark.parametrize(