diff --git a/euphonic/spectra.py b/euphonic/spectra.py index b3099f143..ce7589ca0 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -672,7 +672,7 @@ def broaden(self: T, x_width, Metadata = Dict[str, Union[str, int, LineData]] -class SpectrumCollectionMixin: +class SpectrumCollectionMixin(ABC): """Help a collection of spectra work with "line_data" metadata file This is a Mixin to be inherited by Spectrum collection classes @@ -689,23 +689,124 @@ class SpectrumCollectionMixin: for multi-line plots of Spectrum1DCollection and then applied to other purposes. + The _spectrum_axis class attribute determines which axis property contains + the spectral data, and should be set by subclasses (i.e. to "y" or "z" for + 1D or 2D). """ + # Subclasses must define which axis contains the spectral data for + # purposes of splitting, indexing, etc. + # Python doesn't support abstract class attributes so we define a default + # value, ensuring _something_ was set. + _bin_axes = ("x",) + _spectrum_axis = "y" + _item_type = Spectrum1D + + # Define some private methods which wrap this information into useful forms + def _spectrum_data_name(self) -> str: + return f"{self._spectrum_axis}_data" + + def _spectrum_raw_data_name(self) -> str: + return f"_{self._spectrum_axis}_data" + + def _get_spectrum_data(self) -> Quantity: + return getattr(self, self._spectrum_data_name()) + + def _get_raw_spectrum_data(self) -> np.ndarray: + return getattr(self, self._spectrum_raw_data_name()) + + def _set_spectrum_data(self, data: Quantity) -> None: + setattr(self, self._spectrum_data_name(), data) + + def _set_raw_spectrum_data(self, data: np.ndarray) -> None: + setattr(self, self._spectrum_raw_data_name(), data) + + def _get_spectrum_data_unit(self) -> str: + return getattr(self, f"{self._spectrum_data_name()}_unit") + + def _get_internal_spectrum_data_unit(self) -> str: + return getattr(self, f"_internal_{self._spectrum_data_name()}_unit") + + def _get_bin_kwargs(self) -> Dict[str, Quantity]: + """Get constructor args for bin axes from current data + + e.g. for Spectrum2DCollection this is + + {"x_data": self.x_data, "y_data": self.y_data} + """ + return {f"{axis}_data": getattr(self, f"{axis}_data") + for axis in self._bin_axes} + + def sum(self) -> Spectrum: + """ + Sum collection to a single spectrum + + Returns + ------- + summed_spectrum + A single combined spectrum from all items in collection. Any + metadata in 'line_data' not common across all spectra will be + discarded + """ + metadata = copy.deepcopy(self.metadata) + metadata.pop('line_data', None) + metadata.update(self._tidy_metadata()) + summed_s_data = np.sum(self._get_raw_spectrum_data(), axis=0 + ) * ureg(self._get_internal_spectrum_data_unit() + ).to(self._get_spectrum_data_unit()) + return Spectrum1D( + **self._get_bin_kwargs(), + **{self._spectrum_data_name(): summed_s_data}, + x_tick_labels=copy.copy(self.x_tick_labels), + metadata=metadata + ) + + + # Required methods + @classmethod + @abstractmethod + def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: ... + + # Mixin methods + def __len__(self): + return self._get_raw_spectrum_data().shape[0] + + def copy(self) -> Self: + """Get an independent copy of spectrum""" + return self._item_type.copy(self) + + def __add__(self, other: Self) -> Self: + """ + Appends the y_data of 2 Spectrum1DCollection objects, + creating a single Spectrum1DCollection that contains + the spectra from both objects. The two objects must + have equal x_data axes, and their y_data must + have compatible units and the same number of y_data + entries + + Any metadata key/value pairs that are common to both + spectra are retained in the top level dictionary, any + others are put in the individual 'line_data' entries + """ + return type(self).from_spectra([*self, *other]) + def iter_metadata(self) -> Generator[OneLineData, None, None]: """Iterate over metadata dicts of individual spectra from collection""" - common_metadata = dict((key, self.metadata[key]) for key in self.metadata.keys() - set("line_data")) - from itertools import repeat + common_metadata = dict( + (key, self.metadata[key]) + for key in self.metadata.keys() - set("line_data")) line_data = self.metadata.get("line_data") if line_data is None: - line_data = repeat({}, len(self._z_data)) + line_data = itertools.repeat({}, len(self._z_data)) for one_line_data in line_data: yield common_metadata | one_line_data def _select_indices(self, **select_key_values) -> list[int]: required_metadata = select_key_values.items() - indices = [i for i, row in enumerate(self.iter_metadata()) if required_metadata <= row.items()] + indices = [i for i, row in enumerate(self.iter_metadata()) + if required_metadata <= row.items()] return indices def select(self, **select_key_values: Union[ @@ -796,8 +897,89 @@ def _tidy_metadata(self, indices: Optional[Sequence[int]] = None combined_line_data.pop("line_data", None) return combined_line_data + def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: + """ + Get value of the key(s) for each element in + metadata['line_data']. Returns a 1D array of tuples, where each + tuple contains the value(s) for each key in line_data_keys, for + a single element in metadata['line_data']. This allows easy + grouping/selecting by specific keys + + For example, if we have a Spectrum1DCollection with the following + metadata: + {'desc': 'Quartz', 'line_data': [ + {'inst': 'LET', 'sample': 0, 'index': 1}, + {'inst': 'MAPS', 'sample': 1, 'index': 2}, + {'inst': 'MARI', 'sample': 1, 'index': 1}, + ]} + Then: + _get_line_data_vals('inst', 'sample') = [('LET', 0), + ('MAPS', 1), + ('MARI', 1)] + + Raises a KeyError if 'line_data' or the key doesn't exist + """ + line_data = self.metadata['line_data'] + line_data_vals = np.empty(len(line_data), dtype=object) + for i, data in enumerate(line_data): + line_data_vals[i] = tuple([data[key] for key in line_data_keys]) + return line_data_vals + + def group_by(self, *line_data_keys: str) -> Self: + """ + Group and sum elements of spectral data according to the values + mapped to the specified keys in metadata['line_data'] + + Parameters + ---------- + line_data_keys + The key(s) to group by. If only one line_data_key is + supplied, if the value mapped to a key is the same for + multiple spectra, they are placed in the same group and + summed. If multiple line_data_keys are supplied, the values + must be the same for all specified keys for them to be + placed in the same group + + Returns + ------- + grouped_spectrum + A new Spectrum1DCollection with one line for each group. Any + metadata in 'line_data' not common across all spectra in a + group will be discarded + """ + # Remove line_data_keys that are not found in top level of metadata: + # these will not be useful for grouping + keys = [key for key in line_data_keys if key not in self.metadata] + + # If there are no keys left, sum everything as one big group and return + if not keys: + return self.from_spectra([self.sum()]) + + grouping_dict = _get_unique_elems_and_idx( + self._get_line_data_vals(*line_data_keys)) + + new_s_data = np.zeros((len(grouping_dict), + *self._get_raw_spectrum_data().shape[1:])) + group_metadata = copy.deepcopy(self.metadata) + group_metadata['line_data'] = [{}]*len(grouping_dict) + for i, idxs in enumerate(grouping_dict.values()): + # Look for any common key/values in grouped metadata + group_i_metadata = self._tidy_metadata(idxs) + group_metadata['line_data'][i] = group_i_metadata + new_s_data[i] = np.sum(self._get_raw_spectrum_data()[idxs], axis=0) + new_s_data = new_s_data*ureg(self._get_internal_spectrum_data_unit()).to( + self._get_spectrum_data_unit()) -class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): + new_data = self.copy() + new_data._set_spectrum_data(new_s_data) + new_data.metadata = group_metadata + + return new_data + + +class Spectrum1DCollection(SpectrumCollectionMixin, + Spectrum, + collections.abc.Sequence): """A collection of Spectrum1D with common x_data and x_tick_labels Intended for convenient storage of band structures, projected DOS @@ -831,6 +1013,10 @@ class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Sp """ T = TypeVar('T', bound='Spectrum1DCollection') + # Private attributes used by SpectrumCollectionMixin + _spectrum_axis = "y" + _item_type = Spectrum1D + def __init__( self, x_data: Quantity, y_data: Quantity, x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, @@ -882,24 +1068,9 @@ def __init__( f'{len(metadata["line_data"])} entries') self.metadata = {} if metadata is None else metadata - def __add__(self: T, other: T) -> T: - """ - Appends the y_data of 2 Spectrum1DCollection objects, - creating a single Spectrum1DCollection that contains - the spectra from both objects. The two objects must - have equal x_data axes, and their y_data must - have compatible units and the same number of y_data - entries - - Any metadata key/value pairs that are common to both - spectra are retained in the top level dictionary, any - others are put in the individual 'line_data' entries - """ - return type(self).from_spectra([*self, *other]) - def _split_by_indices(self, indices: Union[Sequence[int], np.ndarray] - ) -> List[T]: + ) -> List[Self]: """Split data along x-axis at given indices""" ranges = self._ranges_from_indices(indices) @@ -910,19 +1081,16 @@ def _split_by_indices(self, metadata=self.metadata) for x0, x1 in ranges] - def __len__(self): - return self.y_data.shape[0] - @overload def __getitem__(self, item: int) -> Spectrum1D: ... @overload # noqa: F811 - def __getitem__(self, item: slice) -> T: + def __getitem__(self, item: slice) -> Self: ... @overload # noqa: F811 - def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> T: + def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> Self: ... def __getitem__(self, item: Union[int, slice, Sequence[int], np.ndarray] @@ -987,38 +1155,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: - """ - Get value of the key(s) for each element in - metadata['line_data']. Returns a 1D array of tuples, where each - tuple contains the value(s) for each key in line_data_keys, for - a single element in metadata['line_data']. This allows easy - grouping/selecting by specific keys - - For example, if we have a Spectrum1DCollection with the following - metadata: - {'desc': 'Quartz', 'line_data': [ - {'inst': 'LET', 'sample': 0, 'index': 1}, - {'inst': 'MAPS', 'sample': 1, 'index': 2}, - {'inst': 'MARI', 'sample': 1, 'index': 1}, - ]} - Then: - _get_line_data_vals('inst', 'sample') = [('LET', 0), - ('MAPS', 1), - ('MARI', 1)] - - Raises a KeyError if 'line_data' or the key doesn't exist - """ - line_data = self.metadata['line_data'] - line_data_vals = np.empty(len(line_data), dtype=object) - for i, data in enumerate(line_data): - line_data_vals[i] = tuple([data[key] for key in line_data_keys]) - return line_data_vals - - def copy(self: T) -> T: - """Get an independent copy of spectrum""" - return Spectrum1D.copy(self) - def to_dict(self) -> Dict[str, Any]: """ Convert to a dictionary consistent with from_dict() @@ -1219,77 +1355,6 @@ def broaden(self: T, else: raise TypeError("x_width must be a Quantity or Callable") - def group_by(self, *line_data_keys: str) -> T: - """ - Group and sum y_data for each spectrum according to the values - mapped to the specified keys in metadata['line_data'] - - Parameters - ---------- - line_data_keys - The key(s) to group by. If only one line_data_key is - supplied, if the value mapped to a key is the same for - multiple spectra, they are placed in the same group and - summed. If multiple line_data_keys are supplied, the values - must be the same for all specified keys for them to be - placed in the same group - - Returns - ------- - grouped_spectrum - A new Spectrum1DCollection with one line for each group. Any - metadata in 'line_data' not common across all spectra in a - group will be discarded - """ - # Remove line_data_keys that are not found in top level of metadata: - # these will not be useful for grouping - keys = [key for key in line_data_keys if key not in self.metadata] - - # If there are no keys left, sum everything as one big group and return - if not keys: - return self.from_spectra([self.sum()]) - - grouping_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*line_data_keys)) - - new_y_data = np.zeros((len(grouping_dict), self._y_data.shape[-1])) - group_metadata = copy.deepcopy(self.metadata) - group_metadata['line_data'] = [{}]*len(grouping_dict) - for i, idxs in enumerate(grouping_dict.values()): - # Look for any common key/values in grouped metadata - group_i_metadata = self._tidy_metadata(idxs) - group_metadata['line_data'][i] = group_i_metadata - new_y_data[i] = np.sum(self._y_data[idxs], axis=0) - new_y_data = new_y_data*ureg(self._internal_y_data_unit).to( - self.y_data_unit) - - new_data = self.copy() - new_data.y_data = new_y_data - new_data.metadata = group_metadata - - return new_data - - def sum(self) -> Spectrum1D: - """ - Sum y_data over all spectra - - Returns - ------- - summed_spectrum - A Spectrum1D created from the summed y_data. Any metadata - in 'line_data' not common across all spectra will be - discarded - """ - metadata = copy.deepcopy(self.metadata) - metadata.pop('line_data', None) - metadata.update(self._tidy_metadata()) - summed_y_data = np.sum(self._y_data, axis=0)*ureg( - self._internal_y_data_unit).to(self.y_data_unit) - return Spectrum1D(np.copy(self.x_data), - summed_y_data, - x_tick_labels=copy.copy(self.x_tick_labels), - metadata=copy.deepcopy(metadata)) - class Spectrum2D(Spectrum): """