Skip to content

Commit

Permalink
Tidying up: rename and move private metadata-handling methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ajjackson committed Jul 12, 2024
1 parent d080074 commit 46eae5f
Showing 1 changed file with 56 additions and 93 deletions.
149 changes: 56 additions & 93 deletions euphonic/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,52 @@ def select(self, **select_key_values: Union[

return self[selected_indices]

@staticmethod
def _combine_metadata(all_metadata: LineData) -> Metadata:
"""
From a sequence of metadata dictionaries, combines all common
key/value pairs into the top level of a metadata dictionary,
all unmatching key/value pairs are put into the 'line_data'
key, which is a list of metadata dicts for each element in
all_metadata
"""
# This is for combining multiple separate spectrum metadata,
# they shouldn't have line_data
for metadata in all_metadata:
assert 'line_data' not in metadata.keys()

# Combine all common key/value pairs into new dict
combined_metadata = dict(
set(all_metadata[0].items()).intersection(
*[metadata.items() for metadata in all_metadata[1:]]))

# Put all other per-spectrum metadata in line_data
line_data = [
{key: value for key, value in metadata.items()
if key not in combined_metadata}
for metadata in all_metadata
]
if any(line_data):
combined_metadata['line_data'] = line_data

return combined_metadata

def _tidy_metadata(self, indices: Optional[Sequence[int]] = None
) -> Metadata:
"""
For a metadata dictionary, combines all common key/value
pairs in 'line_data' and puts them in a top-level dictionary.
If indices is supplied, only those indices in 'line_data' are
combined. Unmatching key/value pairs are discarded
"""
line_data = self.metadata.get("line_data", [{}] * len(self))
if indices is not None:
line_data = [line_data[idx] for idx in indices]
combined_line_data = self._combine_metadata(line_data)
combined_line_data.pop("line_data", None)
return combined_line_data


class Spectrum1DCollection(collections.abc.Sequence, SpectrumCollectionMixin, Spectrum):
"""A collection of Spectrum1D with common x_data and x_tick_labels
Expand Down Expand Up @@ -941,50 +987,6 @@ def _type_check(spectrum):
return cls(x_data, y_data, x_tick_labels=x_tick_labels,
metadata=metadata)

@staticmethod
def _combine_metadata(all_metadata: Sequence[Dict[str, Union[int, str]]]
) -> Dict[str, Union[int, str, LineData]]:
"""
From a sequence of metadata dictionaries, combines all common
key/value pairs into the top level of a metadata dictionary,
all unmatching key/value pairs are put into the 'line_data'
key, which is a list of metadata dicts for each element in
all_metadata
"""
# This is for combining multiple separate spectrum metadata,
# they shouldn't have line_data
for metadata in all_metadata:
assert 'line_data' not in metadata.keys()
# Combine all common key/value pairs
combined_metadata = dict(
set(all_metadata[0].items()).intersection(
*[metadata.items() for metadata in all_metadata[1:]]))
# Put all other per-spectrum metadata in line_data
line_data = []
for i, metadata in enumerate(all_metadata):
sdata = copy.deepcopy(metadata)
for key in combined_metadata.keys():
sdata.pop(key)
line_data.append(sdata)
if any(line_data):
combined_metadata['line_data'] = line_data
return combined_metadata

def _combine_line_metadata(self, indices: Optional[Sequence[int]] = None
) -> Dict[str, Any]:
"""
For a metadata dictionary, combines all common key/value
pairs in 'line_data' and puts them in a top-level dictionary.
If indices is supplied, only those indices in 'line_data' are
combined. Unmatching key/value pairs are discarded
"""
line_data = self.metadata.get('line_data', [{}]*len(self))
if indices is not None:
line_data = [line_data[idx] for idx in indices]
combined_line_data = self._combine_metadata(line_data)
combined_line_data.pop('line_data', None)
return combined_line_data

def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray:
"""
Get value of the key(s) for each element in
Expand Down Expand Up @@ -1239,6 +1241,14 @@ def group_by(self, *line_data_keys: str) -> T:
metadata in 'line_data' not common across all spectra in a
group will be discarded
"""
# Remove line_data_keys that are not found in top level of metadata:
# these will not be useful for grouping
keys = [key for key in line_data_keys if key not in self.metadata]

# If there are no keys left, sum everything as one big group and return
if not keys:
return self.from_spectra([self.sum()])

grouping_dict = _get_unique_elems_and_idx(
self._get_line_data_vals(*line_data_keys))

Expand All @@ -1247,7 +1257,7 @@ def group_by(self, *line_data_keys: str) -> T:
group_metadata['line_data'] = [{}]*len(grouping_dict)
for i, idxs in enumerate(grouping_dict.values()):
# Look for any common key/values in grouped metadata
group_i_metadata = self._combine_line_metadata(idxs)
group_i_metadata = self._tidy_metadata(idxs)
group_metadata['line_data'][i] = group_i_metadata
new_y_data[i] = np.sum(self._y_data[idxs], axis=0)
new_y_data = new_y_data*ureg(self._internal_y_data_unit).to(
Expand All @@ -1272,61 +1282,14 @@ def sum(self) -> Spectrum1D:
"""
metadata = copy.deepcopy(self.metadata)
metadata.pop('line_data', None)
metadata.update(self._combine_line_metadata())
metadata.update(self._tidy_metadata())
summed_y_data = np.sum(self._y_data, axis=0)*ureg(
self._internal_y_data_unit).to(self.y_data_unit)
return Spectrum1D(np.copy(self.x_data),
summed_y_data,
x_tick_labels=copy.copy(self.x_tick_labels),
metadata=copy.deepcopy(metadata))

# def select(self, **select_key_values: Union[
# str, int, Sequence[str], Sequence[int]]) -> T:
# """
# Select spectra by their keys and values in metadata['line_data']

# Parameters
# ----------
# **select_key_values
# Key-value/values pairs in metadata['line_data'] describing
# which spectra to extract. For example, to select all spectra
# where metadata['line_data']['species'] = 'Na' or 'Cl' use
# spectrum.select(species=['Na', 'Cl']). To select 'Na' and
# 'Cl' spectra where weighting is also coherent, use
# spectrum.select(species=['Na', 'Cl'], weighting='coherent')

# Returns
# -------
# selected_spectra
# A Spectrum1DCollection containing the selected spectra

# Raises
# ------
# ValueError
# If no matching spectra are found
# """
# select_val_dict = _get_unique_elems_and_idx(
# self._get_line_data_vals(*select_key_values.keys()))
# for key, value in select_key_values.items():
# if isinstance(value, (int, str)):
# select_key_values[key] = [value]
# value_combinations = itertools.product(*select_key_values.values())
# select_idx = np.array([], dtype=np.int32)
# for value_combo in value_combinations:
# try:
# idx = select_val_dict[value_combo]
# # Don't require every combination to match e.g.
# # spec.select(sample=[0, 2], inst=['MAPS', 'MARI'])
# # we don't want to error simply because there are no
# # inst='MAPS' and sample=2 combinations
# except KeyError:
# continue
# select_idx = np.append(select_idx, idx)
# if len(select_idx) == 0:
# raise ValueError(f'No spectra found with matching metadata '
# f'for {select_key_values}')
# return self[select_idx]


class Spectrum2D(Spectrum):
"""
Expand Down

0 comments on commit 46eae5f

Please sign in to comment.