Skip to content

Commit

Permalink
Create a Mixin for some Spectrum1DCollection methods, rewrite select()
Browse files Browse the repository at this point in the history
- This version of select() should be more robust when dealing with
  parameters that exist at the "top level" of the metadata dict
- I hope it is also easier to understand
  • Loading branch information
ajjackson committed Jul 10, 2024
1 parent 870a9ef commit d080074
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 50 deletions.
180 changes: 131 additions & 49 deletions euphonic/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import math
import json
from numbers import Integral, Real
from typing import (Any, Callable, Dict, List, Literal, Optional, overload,
from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, overload,
Sequence, Tuple, TypeVar, Union, Type)
from typing_extensions import Self
import warnings

from pint import DimensionalityError, Quantity
Expand Down Expand Up @@ -666,10 +667,91 @@ def broaden(self: T, x_width,
return new_spectrum


# Earlier spelling of the LineData alias; it is rebound two lines below.
LineData = Sequence[Dict[str, Union[str, int]]]
# Metadata for one entry ("line") of a collection: string keys mapped to
# string or integer values.
OneLineData = Dict[str, Union[str, int]]
# Per-line metadata for a whole collection: a sequence of OneLineData dicts.
LineData = Sequence[OneLineData]
# Full metadata dict: string/integer values at the top level, plus
# (presumably under the "line_data" key) a LineData sequence.
Metadata = Dict[str, Union[str, int, LineData]]


class Spectrum1DCollection(collections.abc.Sequence, Spectrum):
class SpectrumCollectionMixin:
    """Help a collection of spectra work with "line_data" metadata

    This is a Mixin to be inherited by Spectrum collection classes.

    To avoid redundancy, spectrum collections store metadata in the form

        {"key1": value1, "key2": value2,
         "line_data": [{"key3": value3, ...}, {"key4": value4, ...}, ...]}

    - It is not guaranteed that all "lines" carry the same keys
    - No key should appear at both top-level and in line-data; any key-value
      pair at top level is assumed to apply to all lines
    - "lines" can actually correspond to N-D spectra, the notation was
      devised for multi-line plots of Spectrum1DCollection and then applied
      to other purposes.

    The methods here read ``self.metadata`` and ``self._z_data`` and index
    the collection with a list of integers (``self[indices]``), so the
    inheriting class has to provide all three.
    """

    def iter_metadata(self) -> "Generator[OneLineData, None, None]":
        """Iterate over metadata dicts of individual spectra from collection

        Each yielded dict is the top-level metadata (without the
        "line_data" entry itself) merged with the "line_data" item of one
        spectrum. If there is no "line_data", one copy of the top-level
        metadata is yielded for each of the ``len(self._z_data)`` spectra.
        """
        # Compare the key directly: set("line_data") would be a set of
        # single characters and so would never exclude the "line_data" key
        common_metadata = {key: value
                           for key, value in self.metadata.items()
                           if key != "line_data"}

        line_data = self.metadata.get("line_data")
        if line_data is None:
            line_data = itertools.repeat({}, len(self._z_data))

        for one_line_data in line_data:
            yield common_metadata | one_line_data

    def _select_indices(self, **select_key_values) -> List[int]:
        """Get indices of spectra whose metadata contains all given pairs"""
        required_metadata = select_key_values.items()
        # dict item views are set-like, so "<=" is a subset test: every
        # requested key-value pair must be present in the row
        return [index
                for index, row in enumerate(self.iter_metadata())
                if required_metadata <= row.items()]

    def select(self, **select_key_values: Union[
            str, int, Sequence[str], Sequence[int]]) -> "Self":
        """
        Select spectra by their keys and values in metadata

        Both top-level metadata items (which apply to every spectrum)
        and the per-spectrum items in metadata['line_data'] are
        considered.

        Parameters
        ----------
        **select_key_values
            Key-value/values pairs in the metadata describing which
            spectra to extract. For example, to select all spectra
            where metadata['line_data']['species'] = 'Na' or 'Cl' use
            spectrum.select(species=['Na', 'Cl']). To select 'Na' and
            'Cl' spectra where weighting is also coherent, use
            spectrum.select(species=['Na', 'Cl'], weighting='coherent')

        Returns
        -------
        selected_spectra
            The result of indexing this collection with the list of
            matching indices

        Raises
        ------
        ValueError
            If no matching spectra are found
        """
        # Convert all items to sequences of possibilities
        select_key_values = {
            key: (value,) if isinstance(value, (int, str)) else value
            for key, value in select_key_values.items()
        }

        # Collect indices that match each combination of values. A
        # combination without any match contributes nothing; only a
        # completely empty selection is an error.
        selected_indices = []
        for value_combination in itertools.product(
                *select_key_values.values()):
            selection = dict(zip(select_key_values.keys(), value_combination))
            selected_indices.extend(self._select_indices(**selection))

        if not selected_indices:
            raise ValueError(f'No spectra found with matching metadata '
                             f'for {select_key_values}')

        return self[selected_indices]

class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum):
"""A collection of Spectrum1D with common x_data and x_tick_labels
Intended for convenient storage of band structures, projected DOS
Expand Down Expand Up @@ -1198,52 +1280,52 @@ def sum(self) -> Spectrum1D:
x_tick_labels=copy.copy(self.x_tick_labels),
metadata=copy.deepcopy(metadata))

def select(self, **select_key_values: Union[
        str, int, Sequence[str], Sequence[int]]) -> T:
    """
    Pick out spectra whose metadata['line_data'] entries match

    Parameters
    ----------
    **select_key_values
        Keys in metadata['line_data'] paired with the value (or
        sequence of acceptable values) that a spectrum must have.
        E.g. spectrum.select(species=['Na', 'Cl']) extracts every
        spectrum where metadata['line_data']['species'] is 'Na' or
        'Cl', while
        spectrum.select(species=['Na', 'Cl'], weighting='coherent')
        additionally requires the weighting to be coherent

    Returns
    -------
    selected_spectra
        A Spectrum1DCollection containing the selected spectra

    Raises
    ------
    ValueError
        If no matching spectra are found
    """
    indices_by_values = _get_unique_elems_and_idx(
        self._get_line_data_vals(*select_key_values))

    # Wrap lone values in a list so that every key offers a sequence
    # of candidates to itertools.product
    scalar_keys = [key for key, value in select_key_values.items()
                   if isinstance(value, (int, str))]
    for key in scalar_keys:
        select_key_values[key] = [select_key_values[key]]

    matched = np.array([], dtype=np.int32)
    for candidate in itertools.product(*select_key_values.values()):
        try:
            found = indices_by_values[candidate]
        except KeyError:
            # A combination with no match is not an error in itself:
            # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) should
            # still work when there is no inst='MAPS', sample=2 entry
            continue
        matched = np.append(matched, found)

    if not len(matched):
        raise ValueError(f'No spectra found with matching metadata '
                         f'for {select_key_values}')
    return self[matched]
# def select(self, **select_key_values: Union[
# str, int, Sequence[str], Sequence[int]]) -> T:
# """
# Select spectra by their keys and values in metadata['line_data']

# Parameters
# ----------
# **select_key_values
# Key-value/values pairs in metadata['line_data'] describing
# which spectra to extract. For example, to select all spectra
# where metadata['line_data']['species'] = 'Na' or 'Cl' use
# spectrum.select(species=['Na', 'Cl']). To select 'Na' and
# 'Cl' spectra where weighting is also coherent, use
# spectrum.select(species=['Na', 'Cl'], weighting='coherent')

# Returns
# -------
# selected_spectra
# A Spectrum1DCollection containing the selected spectra

# Raises
# ------
# ValueError
# If no matching spectra are found
# """
# select_val_dict = _get_unique_elems_and_idx(
# self._get_line_data_vals(*select_key_values.keys()))
# for key, value in select_key_values.items():
# if isinstance(value, (int, str)):
# select_key_values[key] = [value]
# value_combinations = itertools.product(*select_key_values.values())
# select_idx = np.array([], dtype=np.int32)
# for value_combo in value_combinations:
# try:
# idx = select_val_dict[value_combo]
# # Don't require every combination to match e.g.
# # spec.select(sample=[0, 2], inst=['MAPS', 'MARI'])
# # we don't want to error simply because there are no
# # inst='MAPS' and sample=2 combinations
# except KeyError:
# continue
# select_idx = np.append(select_idx, idx)
# if len(select_idx) == 0:
# raise ValueError(f'No spectra found with matching metadata '
# f'for {select_key_values}')
# return self[select_idx]


class Spectrum2D(Spectrum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,9 @@ def test_select(self, spectrum_file, select_kwargs,
[3, 5]),
('La2Zr2O7_666_coh_incoh_species_append_pdos.json',
{'weighting': 'incoherent', 'species': 'O'},
[3])
[3]),
('methane_pdos.json',
{'desc': 'Methane PDOS', 'label': 'H3'}, [2]),
])
def test_select_same_as_indexing(self, spectrum_file, select_kwargs,
expected_indices):
Expand Down

0 comments on commit d080074

Please sign in to comment.