From 46eae5fc5d7981c4e417247121fc08202e5e24de Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 12 Jul 2024 13:00:03 +0100 Subject: [PATCH] Tidying up: rename and move private metadata-handling methods --- euphonic/spectra.py | 149 +++++++++++++++++--------------------------- 1 file changed, 56 insertions(+), 93 deletions(-) diff --git a/euphonic/spectra.py b/euphonic/spectra.py index ee7d8a756..b3099f143 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -751,6 +751,52 @@ def select(self, **select_key_values: Union[ return self[selected_indices] + @staticmethod + def _combine_metadata(all_metadata: LineData) -> Metadata: + """ + From a sequence of metadata dictionaries, combines all common + key/value pairs into the top level of a metadata dictionary, + all unmatching key/value pairs are put into the 'line_data' + key, which is a list of metadata dicts for each element in + all_metadata + """ + # This is for combining multiple separate spectrum metadata, + # they shouldn't have line_data + for metadata in all_metadata: + assert 'line_data' not in metadata.keys() + + # Combine all common key/value pairs into new dict + combined_metadata = dict( + set(all_metadata[0].items()).intersection( + *[metadata.items() for metadata in all_metadata[1:]])) + + # Put all other per-spectrum metadata in line_data + line_data = [ + {key: value for key, value in metadata.items() + if key not in combined_metadata} + for metadata in all_metadata + ] + if any(line_data): + combined_metadata['line_data'] = line_data + + return combined_metadata + + def _tidy_metadata(self, indices: Optional[Sequence[int]] = None + ) -> Metadata: + """ + For a metadata dictionary, combines all common key/value + pairs in 'line_data' and puts them in a top-level dictionary. + If indices is supplied, only those indices in 'line_data' are + combined. 
Unmatching key/value pairs are discarded + """ + line_data = self.metadata.get("line_data", [{}] * len(self)) + if indices is not None: + line_data = [line_data[idx] for idx in indices] + combined_line_data = self._combine_metadata(line_data) + combined_line_data.pop("line_data", None) + return combined_line_data + + class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): """A collection of Spectrum1D with common x_data and x_tick_labels @@ -941,50 +987,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - @staticmethod - def _combine_metadata(all_metadata: Sequence[Dict[str, Union[int, str]]] - ) -> Dict[str, Union[int, str, LineData]]: - """ - From a sequence of metadata dictionaries, combines all common - key/value pairs into the top level of a metadata dictionary, - all unmatching key/value pairs are put into the 'line_data' - key, which is a list of metadata dicts for each element in - all_metadata - """ - # This is for combining multiple separate spectrum metadata, - # they shouldn't have line_data - for metadata in all_metadata: - assert 'line_data' not in metadata.keys() - # Combine all common key/value pairs - combined_metadata = dict( - set(all_metadata[0].items()).intersection( - *[metadata.items() for metadata in all_metadata[1:]])) - # Put all other per-spectrum metadata in line_data - line_data = [] - for i, metadata in enumerate(all_metadata): - sdata = copy.deepcopy(metadata) - for key in combined_metadata.keys(): - sdata.pop(key) - line_data.append(sdata) - if any(line_data): - combined_metadata['line_data'] = line_data - return combined_metadata - - def _combine_line_metadata(self, indices: Optional[Sequence[int]] = None - ) -> Dict[str, Any]: - """ - For a metadata dictionary, combines all common key/value - pairs in 'line_data' and puts them in a top-level dictionary. - If indices is supplied, only those indices in 'line_data' are - combined. 
Unmatching key/value pairs are discarded - """ - line_data = self.metadata.get('line_data', [{}]*len(self)) - if indices is not None: - line_data = [line_data[idx] for idx in indices] - combined_line_data = self._combine_metadata(line_data) - combined_line_data.pop('line_data', None) - return combined_line_data - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: """ Get value of the key(s) for each element in @@ -1239,6 +1241,14 @@ def group_by(self, *line_data_keys: str) -> T: metadata in 'line_data' not common across all spectra in a group will be discarded """ + # Remove line_data_keys that are found in top level of metadata: + # these will not be useful for grouping + keys = [key for key in line_data_keys if key not in self.metadata] + + # If there are no keys left, sum everything as one big group and return + if not keys: + return self.from_spectra([self.sum()]) + grouping_dict = _get_unique_elems_and_idx( self._get_line_data_vals(*line_data_keys)) @@ -1247,7 +1257,7 @@ def group_by(self, *line_data_keys: str) -> T: group_metadata['line_data'] = [{}]*len(grouping_dict) for i, idxs in enumerate(grouping_dict.values()): # Look for any common key/values in grouped metadata - group_i_metadata = self._combine_line_metadata(idxs) + group_i_metadata = self._tidy_metadata(idxs) group_metadata['line_data'][i] = group_i_metadata new_y_data[i] = np.sum(self._y_data[idxs], axis=0) new_y_data = new_y_data*ureg(self._internal_y_data_unit).to( @@ -1272,7 +1282,7 @@ def sum(self) -> Spectrum1D: """ metadata = copy.deepcopy(self.metadata) metadata.pop('line_data', None) - metadata.update(self._combine_line_metadata()) + metadata.update(self._tidy_metadata()) summed_y_data = np.sum(self._y_data, axis=0)*ureg( self._internal_y_data_unit).to(self.y_data_unit) return Spectrum1D(np.copy(self.x_data), @@ -1280,53 +1290,6 @@ def sum(self) -> Spectrum1D: x_tick_labels=copy.copy(self.x_tick_labels), metadata=copy.deepcopy(metadata)) - # def select(self, 
**select_key_values: Union[ - # str, int, Sequence[str], Sequence[int]]) -> T: - # """ - # Select spectra by their keys and values in metadata['line_data'] - - # Parameters - # ---------- - # **select_key_values - # Key-value/values pairs in metadata['line_data'] describing - # which spectra to extract. For example, to select all spectra - # where metadata['line_data']['species'] = 'Na' or 'Cl' use - # spectrum.select(species=['Na', 'Cl']). To select 'Na' and - # 'Cl' spectra where weighting is also coherent, use - # spectrum.select(species=['Na', 'Cl'], weighting='coherent') - - # Returns - # ------- - # selected_spectra - # A Spectrum1DCollection containing the selected spectra - - # Raises - # ------ - # ValueError - # If no matching spectra are found - # """ - # select_val_dict = _get_unique_elems_and_idx( - # self._get_line_data_vals(*select_key_values.keys())) - # for key, value in select_key_values.items(): - # if isinstance(value, (int, str)): - # select_key_values[key] = [value] - # value_combinations = itertools.product(*select_key_values.values()) - # select_idx = np.array([], dtype=np.int32) - # for value_combo in value_combinations: - # try: - # idx = select_val_dict[value_combo] - # # Don't require every combination to match e.g. - # # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) - # # we don't want to error simply because there are no - # # inst='MAPS' and sample=2 combinations - # except KeyError: - # continue - # select_idx = np.append(select_idx, idx) - # if len(select_idx) == 0: - # raise ValueError(f'No spectra found with matching metadata ' - # f'for {select_key_values}') - # return self[select_idx] - class Spectrum2D(Spectrum): """