diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 3626cd29d..ee7d8a756 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -5,8 +5,9 @@ import math import json from numbers import Integral, Real -from typing import (Any, Callable, Dict, List, Literal, Optional, overload, +from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload, Sequence, Tuple, TypeVar, Union, Type) +from typing_extensions import Self import warnings from pint import DimensionalityError, Quantity @@ -666,10 +667,91 @@ def broaden(self: T, x_width, return new_spectrum -LineData = Sequence[Dict[str, Union[str, int]]] +OneLineData = Dict[str, Union[str, int]] +LineData = Sequence[OneLineData] +Metadata = Dict[str, Union[str, int, LineData]] -class Spectrum1DCollection(collections.abc.Sequence, Spectrum): +class SpectrumCollectionMixin: + """Help a collection of spectra work with "line_data" metadata file + + This is a Mixin to be inherited by Spectrum collection classes + + To avoid redundancy, spectrum collections store metadata in the form + + {"key1": value1, "key2", value2, "line_data": [{"key3": value3, ...}, + {"key4": value4, ...}...]} + + - It is not guaranteed that all "lines" carry the same keys + - No key should appear at both top-level and in line-data; any key-value + pair at top level is assumed to apply to all lines + - "lines" can actually correspond to N-D spectra, the notation was devised + for multi-line plots of Spectrum1DCollection and then applied to other + purposes. + + """ + + def iter_metadata(self) -> Generator[OneLineData, None, None]: + """Iterate over metadata dicts of individual spectra from collection""" + common_metadata = dict((key, self.metadata[key]) for key in self.metadata.keys() - set("line_data")) + from itertools import repeat + + line_data = self.metadata.get("line_data") + if line_data is None: + line_data = repeat({}, len(self._z_data)) + + for one_line_data in line_data: + yield common_metadata | one_line_data + + def _select_indices(self, **select_key_values) -> list[int]: + required_metadata = select_key_values.items() + indices = [i for i, row in enumerate(self.iter_metadata()) if required_metadata <= row.items()] + return indices + + def select(self, **select_key_values: Union[ + str, int, Sequence[str], Sequence[int]]) -> Self: + """ + Select spectra by their keys and values in metadata['line_data'] + + Parameters + ---------- + **select_key_values + Key-value/values pairs in metadata['line_data'] describing + which spectra to extract. For example, to select all spectra + where metadata['line_data']['species'] = 'Na' or 'Cl' use + spectrum.select(species=['Na', 'Cl']). To select 'Na' and + 'Cl' spectra where weighting is also coherent, use + spectrum.select(species=['Na', 'Cl'], weighting='coherent') + + Returns + ------- + selected_spectra + A Spectrum1DCollection containing the selected spectra + + Raises + ------ + ValueError + If no matching spectra are found + """ + # Convert all items to sequences of possibilities + select_key_values = dict( + (key, (value,)) if isinstance(value, (int, str)) else (key, value) + for key, value in select_key_values.items() + ) + + # Collect indices that match each combination of values + selected_indices = [] + for value_combination in itertools.product(*select_key_values.values()): + selection = dict(zip(select_key_values.keys(), value_combination)) + selected_indices.extend(self._select_indices(**selection)) + + if not selected_indices: + raise ValueError(f'No spectra found with matching metadata ' + f'for {select_key_values}') + + return self[selected_indices] + +class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum): """A collection of Spectrum1D with common x_data and x_tick_labels Intended for convenient storage of band structures, projected DOS @@ -1198,52 +1280,52 @@ def sum(self) -> Spectrum1D: x_tick_labels=copy.copy(self.x_tick_labels), metadata=copy.deepcopy(metadata)) - def select(self, **select_key_values: Union[ - str, int, Sequence[str], Sequence[int]]) -> T: - """ - Select spectra by their keys and values in metadata['line_data'] - - Parameters - ---------- - **select_key_values - Key-value/values pairs in metadata['line_data'] describing - which spectra to extract. For example, to select all spectra - where metadata['line_data']['species'] = 'Na' or 'Cl' use - spectrum.select(species=['Na', 'Cl']). To select 'Na' and - 'Cl' spectra where weighting is also coherent, use - spectrum.select(species=['Na', 'Cl'], weighting='coherent') - - Returns - ------- - selected_spectra - A Spectrum1DCollection containing the selected spectra - - Raises - ------ - ValueError - If no matching spectra are found - """ - select_val_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*select_key_values.keys())) - for key, value in select_key_values.items(): - if isinstance(value, (int, str)): - select_key_values[key] = [value] - value_combinations = itertools.product(*select_key_values.values()) - select_idx = np.array([], dtype=np.int32) - for value_combo in value_combinations: - try: - idx = select_val_dict[value_combo] - # Don't require every combination to match e.g. - # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) - # we don't want to error simply because there are no - # inst='MAPS' and sample=2 combinations - except KeyError: - continue - select_idx = np.append(select_idx, idx) - if len(select_idx) == 0: - raise ValueError(f'No spectra found with matching metadata ' - f'for {select_key_values}') - return self[select_idx] + # def select(self, **select_key_values: Union[ + # str, int, Sequence[str], Sequence[int]]) -> T: + # """ + # Select spectra by their keys and values in metadata['line_data'] + + # Parameters + # ---------- + # **select_key_values + # Key-value/values pairs in metadata['line_data'] describing + # which spectra to extract. For example, to select all spectra + # where metadata['line_data']['species'] = 'Na' or 'Cl' use + # spectrum.select(species=['Na', 'Cl']). To select 'Na' and + # 'Cl' spectra where weighting is also coherent, use + # spectrum.select(species=['Na', 'Cl'], weighting='coherent') + + # Returns + # ------- + # selected_spectra + # A Spectrum1DCollection containing the selected spectra + + # Raises + # ------ + # ValueError + # If no matching spectra are found + # """ + # select_val_dict = _get_unique_elems_and_idx( + # self._get_line_data_vals(*select_key_values.keys())) + # for key, value in select_key_values.items(): + # if isinstance(value, (int, str)): + # select_key_values[key] = [value] + # value_combinations = itertools.product(*select_key_values.values()) + # select_idx = np.array([], dtype=np.int32) + # for value_combo in value_combinations: + # try: + # idx = select_val_dict[value_combo] + # # Don't require every combination to match e.g. + # # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) + # # we don't want to error simply because there are no + # # inst='MAPS' and sample=2 combinations + # except KeyError: + # continue + # select_idx = np.append(select_idx, idx) + # if len(select_idx) == 0: + # raise ValueError(f'No spectra found with matching metadata ' + # f'for {select_key_values}') + # return self[select_idx] class Spectrum2D(Spectrum): diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py index 24aedbd5e..cca15fe5d 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py @@ -623,7 +623,9 @@ def test_select(self, spectrum_file, select_kwargs, [3, 5]), ('La2Zr2O7_666_coh_incoh_species_append_pdos.json', {'weighting': 'incoherent', 'species': 'O'}, - [3]) + [3]), + ('methane_pdos.json', + {'desc': 'Methane PDOS', 'label': 'H3'}, [2]), ]) def test_select_same_as_indexing(self, spectrum_file, select_kwargs, expected_indices):