From e3fbdca0bd72bd7a0df56a324dcab0ed3181daec Mon Sep 17 00:00:00 2001 From: Garrett 'Karto' Keating Date: Mon, 4 Mar 2024 20:46:59 -0500 Subject: [PATCH] Adding support for CORRECTED_DATA, MODEL_DATA read/write to CASA MS --- pyuvdata/ms_utils.py | 34 +++++++++++++++- pyuvdata/uvdata/ms.py | 94 ++++++++++++++++++++++++++++--------------- 2 files changed, 95 insertions(+), 33 deletions(-) diff --git a/pyuvdata/ms_utils.py b/pyuvdata/ms_utils.py index a584623d96..3156952043 100644 --- a/pyuvdata/ms_utils.py +++ b/pyuvdata/ms_utils.py @@ -1303,7 +1303,7 @@ def write_ms_polarization( pol_table.putcell("NUM_CORR", 0, len(polarization_array)) -def init_ms_file(filepath): +def init_ms_file(filepath, make_model_col=False, make_corr_col=False): """ Create a skeleton MS dataset to fill. @@ -1311,6 +1311,12 @@ def init_ms_file(filepath): ---------- filepath : str Path to MS to be created. + make_model_col : bool + If set to True, will construct a measurement set that contains a MODEL_DATA + column in addition to the DATA column. Default is False. + make_model_col : bool + If set to True, will construct a measurement set that contains a CORRECTED_DATA + column in addition to the DATA column. Default is False. """ # The required_ms_desc returns the defaults for a CASA MS table ms_desc = tables.required_ms_desc("MAIN") @@ -1395,6 +1401,32 @@ def init_ms_file(filepath): del datacoldesc["desc"]["shape"] ms_desc.update(tables.maketabdesc(datacoldesc)) + if make_model_col: + datacoldesc = tables.makearrcoldesc( + "MODEL_DATA", + 0.0 + 0.0j, + valuetype="complex", + ndim=2, + datamanagertype="TiledShapeStMan", + datamanagergroup="TiledData", + comment="The data column", + ) + del datacoldesc["desc"]["shape"] + ms_desc.update(tables.maketabdesc(datacoldesc)) + + if make_corr_col: + datacoldesc = tables.makearrcoldesc( + "CORRECTED_DATA", + 0.0 + 0.0j, + valuetype="complex", + ndim=2, + datamanagertype="TiledShapeStMan", + datamanagergroup="TiledData", + comment="The data column", + ) + del datacoldesc["desc"]["shape"] + ms_desc.update(tables.maketabdesc(datacoldesc)) + # Now create a column for the weight spectrum, which we plug nsample_array into weightspeccoldesc = tables.makearrcoldesc( "WEIGHT_SPECTRUM", diff --git a/pyuvdata/uvdata/ms.py b/pyuvdata/uvdata/ms.py index d32437909a..e2ceef0c61 100644 --- a/pyuvdata/uvdata/ms.py +++ b/pyuvdata/uvdata/ms.py @@ -47,6 +47,8 @@ def write_ms( filepath, *, force_phase=False, + model_data=None, + corrected_data=None, flip_conj=False, clobber=False, run_check=True, @@ -66,6 +68,14 @@ def write_ms( force_phase : bool Option to automatically phase drift scan data to zenith of the first timestamp. + model_data : ndarray + Optional argument, which contains data to be written into the MODEL_DATA + column of the measurement set (along with the data, which is written into + the DATA column). Must contain the same dimensions as `data_array`. + corrected_data : ndarray + Optional argument, which contains data to be written into the CORRECTED_DATA + column of the measurement set (along with the data, which is written into + the DATA column). Must contain the same dimensions as `data_array`. clobber : bool Option to overwrite the file if it already exists. flip_conj : bool @@ -73,7 +83,7 @@ def write_ms( -1) and the visibilities are complex conjugated prior to write, such that the data are written with the "opposite" conjugation scheme to what UVData normally uses. Note that this is only needed for specific subset of - applications that read MS-formated data, and should only be used by expert + applications that read MS-formatted data, and should only be used by expert users. Default is False. run_check : bool Option to check for the existence and proper shapes of parameters @@ -146,11 +156,34 @@ def write_ms( self._set_scan_numbers() # Initialize a skelton measurement set - ms = ms_utils.init_ms_file(filepath) + ms = ms_utils.init_ms_file( + filepath, + make_model_col=model_data is not None, + make_corr_col=corrected_data is not None, + ) - attr_list = ["data_array", "nsample_array", "flag_array"] + arr_list = [self.data_array, self.nsample_array, self.flag_array] col_list = ["DATA", "WEIGHT_SPECTRUM", "FLAG"] + if model_data is not None: + assert ( + model_data.shape == self.data_array.shape + ), "model_data must have the same shape as data_array." + arr_list.append(model_data) + col_list.append("MODEL_DATA") + if corrected_data is not None: + assert ( + corrected_data.shape == self.data_array.shape + ), "corrected_data must have the same shape as data_array." + arr_list.append(corrected_data) + col_list.append("CORRECTED_DATA") + + if not self.future_array_shapes: + for idx in range(len(arr_list)): + # If using future array shapes, squeeze the arrays now (which just + # returns a view, and therefore doesn't impact mem usage). + arr_list[idx] = np.squeeze(arr_list[idx], axis=1) + # Some tasks in CASA require a band-representative (band-averaged?) value for # the weights and noise for all channels in each row in the MAIN table, which # we will roughly calculate in temp_weights below. @@ -172,14 +205,11 @@ def write_ms( if self.Nspws == 1: # If we only have one spectral window, there is nothing we need to worry # about ordering, so just write the data-related arrays as is to disk - for attr, col in zip(attr_list, col_list): - if self.future_array_shapes: - temp_vals = getattr(self, attr)[:, :, pol_order] - else: - temp_vals = np.squeeze(getattr(self, attr), axis=1)[..., pol_order] + for arr, col in zip(arr_list, col_list): + temp_vals = arr[:, :, pol_order] - if flip_conj and (attr == "data_array"): - temp_vals = np.conj(temp_vals) + if flip_conj and ("DATA" in col): + temp_vals = np.conj(temp_vals, out=temp_vals) ms.putcol(col, temp_vals) @@ -202,7 +232,7 @@ def write_ms( # (n.b., tables.putvarcol can write complex tables like these, but its # slower and more memory-intensive than putcol). - # Since muliple records trace back to a single baseline-time, we use this + # Since multiple records trace back to a single baseline-time, we use this # array to map from arrays that store on a per-record basis to positions # within arrays that record metadata on a per-blt basis. blt_map_array = np.zeros((self.Nblts * self.Nspws), dtype=int) @@ -244,20 +274,15 @@ def write_ms( # Extract out the relevant data out of our data-like arrays that # belong to this scan number. val_dict = {} - for attr, col in zip(attr_list, col_list): - if self.future_array_shapes: - val_dict[col] = getattr(self, attr)[scan_slice] - else: - val_dict[col] = np.squeeze( - getattr(self, attr)[scan_slice], axis=1 - ) + for arr, col in zip(arr_list, col_list): + temp_arr = arr[scan_slice] + + if flip_conj and ("DATA" in col): + temp_arr = np.conjugate(temp_arr) # Have to do this separately since uou can't supply multiple index # arrays at once. - val_dict[col] = val_dict[col][:, :, pol_order] - - if flip_conj: - val_dict["DATA"] = np.conj(val_dict["DATA"]) + val_dict[col] = temp_arr[:, :, pol_order] # This is where the bulk of the heavy lifting is - use the per-spw # channel masks to record one spectral window at a time. @@ -277,7 +302,7 @@ def write_ms( ) last_row += Nrecs - # Now that we have an array to map baselime-time to individual records, + # Now that we have an array to map baseline-time to individual records, # use our indexing array to map various metadata. ant_1_array = self.ant_1_array[blt_map_array] ant_2_array = self.ant_2_array[blt_map_array] @@ -304,7 +329,7 @@ def write_ms( ms.putcol("ANTENNA1", ant_1_array) ms.putcol("ANTENNA2", ant_2_array) - # "INVERVAL" refers to "width" of the window of time time over which data was + # "INTERVAL" refers to "width" of the window of time time over which data was # collected, while "EXPOSURE" is the sum total of integration time. UVData # does not differentiate between these concepts, hence why one array is used # for both values. @@ -378,7 +403,7 @@ def _read_ms_main( The measurement set root directory to read from. data_column : str name of CASA data column to read into data_array. Options are: - 'DATA', 'MODEL', or 'CORRECTED_DATA' + 'DATA', 'MODEL_DATA', or 'CORRECTED_DATA' data_desc_dict : dict Dictionary describing the various rows in the DATA_DESCRIPTION table of an MS file. Keys match to the individual rows, and the values are themselves @@ -435,13 +460,18 @@ def _read_ms_main( if "pyuvdata_xorient" in main_keywords.keys(): self.x_orientation = main_keywords["pyuvdata_xorient"] - default_vis_units = {"DATA": "uncalib", "CORRECTED_DATA": "Jy", "MODEL": "Jy"} + default_vis_units = { + "DATA": "uncalib", + "CORRECTED_DATA": "Jy", + "MODEL": "Jy", + "MODEL_DATA": "Jy", + } # make sure user requests a valid data_column if data_column not in default_vis_units.keys(): raise ValueError( - "Invalid data_column value supplied. Use 'Data','MODEL' or" - " 'CORRECTED_DATA'" + "Invalid data_column value supplied. Use 'DATA','MODEL_DATA', or" + "'CORRECTED_DATA'." ) # set visibility units @@ -491,11 +521,11 @@ def _read_ms_main( data_desc_count = np.sum(np.isin(list(data_desc_dict.keys()), unique_data_desc)) if data_desc_count == 0: - # If there are no records selected, then there isnt a whole lot to do + # If there are no records selected, then there isn't a whole lot to do return None, None, None, None elif data_desc_count == 1: # If we only have a single spectral window, then we can bypass a whole lot - # of slicing and dicing on account of there being a one-to-one releationship + # of slicing and dicing on account of there being a one-to-one relationship # in rows of the MS to the per-blt records of UVData objects. self.time_array = Time( time_arr / 86400.0, format="mjd", scale=timescale.lower() @@ -759,9 +789,9 @@ def _read_ms_main( # Note that this operation has to be split in two because you can only use # advanced slicing on one axis (which both blt_idx and pol_idx require). if flip_conj: - temp_data[:, :, pol_idx] = np.conj(tb_main_sel.getcol("DATA")) + temp_data[:, :, pol_idx] = np.conj(tb_main_sel.getcol(data_column)) else: - temp_data[:, :, pol_idx] = tb_main_sel.getcol("DATA") + temp_data[:, :, pol_idx] = tb_main_sel.getcol(data_column) temp_flags[:, :, pol_idx] = tb_main_sel.getcol("FLAG")