Skip to content

Commit

Permalink
Add normalizer for assigning parameters to long names
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Dec 20, 2024
1 parent 710c0ae commit 8549eaa
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
39 changes: 16 additions & 23 deletions src/nomad_parser_plugin_boss/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
)

if TYPE_CHECKING:
from structlog.stdlib import (
BoundLogger,
)
from structlog.stdlib import BoundLogger
from nomad.datamodel.datamodel import EntryArchive

import os
import numpy as np

from boss.bo.results import BOResults
from boss.io.dump import build_query_points
from boss.pp.pp_main import PPMain
from nomad.config import config
from nomad.datamodel.datamodel import EntryArchive

from nomad.datamodel.metainfo.annotations import H5WebAnnotation
from nomad.parsing.file_parser.text_parser import Quantity as TextQuantity
from nomad.parsing.file_parser.text_parser import TextParser
from nomad.parsing.parser import MatchingParser

from nomad_parser_plugin_boss.schema_packages.schema_package import (
generate_slices,
PotentialEnergySurfaceFit,
)

from nomad.config import config
configuration = config.get_plugin_entry_point(
'nomad_parser_plugin_boss.parsers:parser_entry_point'
)
Expand Down Expand Up @@ -91,9 +91,9 @@ def reshaping(target: list, dim_1: int, dim_2: int) -> np.ndarray:
def parse(
self,
mainfile: str,
archive: EntryArchive,
archive: 'EntryArchive',
logger: 'BoundLogger',
child_archives: dict[str, EntryArchive] = None,
child_archives: dict[str, 'EntryArchive'] = None,
) -> None:
logger.info('BossPostProcessingParser.parse', parameter=configuration.parameter)

Expand All @@ -109,13 +109,6 @@ def parse(
pp_model_slice=[1, 2, no_grid_points],
)
bounds = pp.settings.get('bounds', [])
ranks = len(bounds)

@staticmethod
def generate_slices():
for main_rank in range(ranks):
for upper_rank in range(main_rank + 1, ranks):
yield main_rank, upper_rank

@staticmethod
def compute_parameters(rank: int):
Expand All @@ -125,7 +118,7 @@ def compute_parameters(rank: int):
archive.data = PotentialEnergySurfaceFit()

# Generate slices
for parameter_counter, rank in enumerate(generate_slices()):
for parameter_counter, rank in enumerate(generate_slices(len(bounds))):
main_rank, upper_rank = rank
mu_all_slices, var_all_slices = [], []
for iteration in range(iter_no):
Expand All @@ -140,20 +133,20 @@ def compute_parameters(rank: int):
) # ? change to local minima

mu, var = res.reconstruct_model(iteration + 1).predict(X)
mu_all_slices.append(
mu.reshape(no_grid_points, no_grid_points)
)
var_all_slices.append(
var.reshape(no_grid_points, no_grid_points)
)
mu_all_slices.append(mu.reshape(no_grid_points, no_grid_points))
var_all_slices.append(var.reshape(no_grid_points, no_grid_points))

# Save slices
slice_path = f'parameter_slices/{parameter_counter}'
section = archive.data.m_setdefault(slice_path)
if parameter_counter == 0:
archive.data.parameter_slices[0].m_annotations['h5web'] = H5WebAnnotation(auxiliary_signals=[])
archive.data.parameter_slices[0].m_annotations['h5web'] = (
H5WebAnnotation(auxiliary_signals=[])
)
else:
archive.data.parameter_slices[0].m_annotations['h5web'].auxiliary_signals.append('../' + slice_path)
archive.data.parameter_slices[0].m_annotations[
'h5web'
].auxiliary_signals.append('../' + slice_path)

section.fitted_values = np.array(mu_all_slices)
section.fitted_stddevs = np.sqrt(var_all_slices)
Expand Down
42 changes: 33 additions & 9 deletions src/nomad_parser_plugin_boss/schema_packages/schema_package.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import TYPE_CHECKING, Generator

if TYPE_CHECKING:
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

import numpy as np

from nomad.datamodel.data import Schema, ArchiveSection
from nomad.datamodel.data import Schema
from nomad.datamodel.hdf5 import HDF5Dataset
from nomad.datamodel.metainfo.annotations import (
H5WebAnnotation,
Expand All @@ -12,6 +18,13 @@
m_package = SchemaPackage()


def generate_slices(ranks: int) -> Generator:
"""Produce all possible index pairs defining slices of the parameter space."""
for main_rank in range(ranks):
for upper_rank in range(main_rank + 1, ranks):
yield main_rank, upper_rank


class ParameterSpaceSlice(Schema):
# ! TODO use `PhysicalProperty`
m_def = Section(
Expand All @@ -22,12 +35,12 @@ class ParameterSpaceSlice(Schema):
)

fitted_values = Quantity(
type=HDF5Dataset, unit='eV', a_h5web=H5WebAnnotation(long_name='PES', errors='fitted_stddevs')
type=HDF5Dataset,
unit='eV',
a_h5web=H5WebAnnotation(long_name='PES', errors='fitted_stddevs'),
) # ? units

fitted_stddevs = Quantity(
type=HDF5Dataset, unit='eV'
)
fitted_stddevs = Quantity(type=HDF5Dataset, unit='eV')

parameter_1_values = Quantity(
type=HDF5Dataset, a_h5web=H5WebAnnotation(long_name='')
Expand All @@ -37,17 +50,17 @@ class ParameterSpaceSlice(Schema):
type=HDF5Dataset, a_h5web=H5WebAnnotation(long_name='')
)

blank = Quantity(
type=HDF5Dataset, a_h5web=H5WebAnnotation()
)
blank = Quantity(type=HDF5Dataset, a_h5web=H5WebAnnotation())

def normalize(self, archive, logger):
self.blank = np.array([])


class PotentialEnergySurfaceFit(Schema):
m_def = Section(
a_h5web=H5WebAnnotation(title='Potential Energy Surface Fit', paths=['parameter_slices/0']),
a_h5web=H5WebAnnotation(
title='Potential Energy Surface Fit', paths=['parameter_slices/0']
),
)

parameter_names = Quantity(
Expand All @@ -59,4 +72,15 @@ class PotentialEnergySurfaceFit(Schema):
parameter_slices = SubSection(sub_section=ParameterSpaceSlice.m_def, repeats=True)


def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger'):
if isinstance(self.parameter_names, list):
if len(self.parameter_names) == (n_slices := len(self.parameter_slices)):
for slice_indices, parameter_slice in zip(generate_slices(n_slices), self.parameter_slices):
main_rank, upper_rank = slice_indices
parameter_slice.parameter_1_values.m_annotations['h5web'].long_name = self.parameter_names[main_rank]
parameter_slice.parameter_2_values.m_annotations['h5web'].long_name = self.parameter_names[upper_rank]
else:
logger.warning('Length mismatch between parameter names and slices. Not updating annotations.', n_names=len(self.parameter_names), n_slices=n_slices)


m_package.__init_metainfo__()

0 comments on commit 8549eaa

Please sign in to comment.