Skip to content

Commit

Permalink
Add tests for APWLChannel
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed Aug 21, 2024
1 parent b99c577 commit 7b298a2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 20 deletions.
56 changes: 37 additions & 19 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.datamodel import EntryArchive
from nomad.datamodel.metainfo.annotations import ELNAnnotation
from nomad.metainfo import MEnum, Quantity, SubSection
from nomad.units import ureg
import itertools
import numpy as np
import pint
from scipy import constants as const
from structlog.stdlib import BoundLogger
from typing import Optional, Any, Callable
from typing import Iterable, Optional, Any, Callable

from nomad import utils
from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.datamodel import EntryArchive
from nomad.datamodel.metainfo.annotations import ELNAnnotation
from nomad.metainfo import MEnum, Quantity, SubSection
from nomad.units import ureg

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.numerical_settings import (
Expand All @@ -17,6 +19,8 @@
)
from nomad_simulations.schema_packages.properties.energies import EnergyContribution

logger = utils.get_logger(__name__)


def check_normalized(func: Callable):
"""
Expand Down Expand Up @@ -259,6 +263,14 @@ def get_n_terms(
else:
return lengths[0]

def _check_non_negative(self, quantity_names: set[str]) -> bool:
"""Check if all elements in the set are non-negative."""
for quantity_name in quantity_names:
if isinstance(quant := self.get(quantity_name), Iterable):
if np.any(np.array(quant) > 0):
return False
return True

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)

Expand All @@ -274,10 +286,12 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
self.n_terms = None

# enforce differential order constraints
if np.any(np.isneginf(self.differential_order)):
logger.error(
'`APWBaseOrbital.differential_order` must be completely non-negative.'
)
if self._check_non_negative({'differential_order'}):
self.differential_order = None # ? appropriate
if logger is not None:
logger.error(
'`APWBaseOrbital.differential_order` must be completely non-negative. Resetting to `None`.'
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -397,10 +411,12 @@ def bo_terms_to_type(self, bo_terms: Optional[int]) -> Optional[str]:
@check_normalized
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
if np.any(np.isneginf(self.boundary_order)):
logger.error(
'`APWLOrbital.boundary_order` must be completely non-negative.'
)
if self._check_non_negative({'boundary_order'}):
self.boundary_order = None # ? appropriate
if logger is not None:
logger.error(
'`APWLOrbital.boundary_order` must be completely non-negative. Resetting to `None`.'
)


class APWLChannel(BasisSet):
Expand Down Expand Up @@ -429,14 +445,16 @@ class APWLChannel(BasisSet):
def _determine_apw(self) -> dict[str, int]:
"""
Produce a count of the APW components in the l-channel.
Invokes `normalize` on `orbitals`.
Invokes `normalize` on `orbitals` to ensure the existence of `type`.
"""
for orb in self.orbitals:
orb.normalize(None, None)
orb.normalize(None, logger)

type_count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0}
for orb in self.orbitals:
if isinstance(orb, APWOrbital) and orb.type.lower() in type_count.keys():
if orb.type is None:
type_count['other'] += 1
elif isinstance(orb, APWOrbital) and orb.type.lower() in type_count.keys():
type_count[orb.type] += 1
elif isinstance(orb, APWLocalOrbital):
type_count['lo'] += 1
Expand Down Expand Up @@ -494,7 +512,7 @@ def _determine_apw(self) -> dict[str, int]:
Invokes `normalize` on `l_channels`.
"""
for l_channel in self.l_channels:
l_channel.normalize(None, None)
l_channel.normalize(None, logger)

type_count: dict[str, int]
if len(self.l_channels) > 0:
Expand Down Expand Up @@ -541,7 +559,7 @@ def _determine_apw(self) -> Optional[str]:
answer, has_plane_wave = '', False
for comp in self.basis_set_components:
if isinstance(comp, MuffinTinRegion):
comp.normalize(None, None)
comp.normalize(None, logger)
type_count = comp._determine_apw()
if sum([type_count[i] for i in ('apw', 'lapw', 'slapw')]) > 0:
if type_count['slapw'] > 0:
Expand Down
35 changes: 34 additions & 1 deletion tests/test_basis_set.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import chain
from typing import Any, Optional

from nomad_simulations.schema_packages.atoms_state import AtomsState
Expand Down Expand Up @@ -156,7 +157,39 @@ def test_apw_local_orbital(
orb = APWLocalOrbital(
energy_parameter=e,
differential_order=d_o,
boundary_order=d_o,
boundary_order=b_o,
)
assert orb.get_n_terms() == ref_n_terms
assert orb.bo_terms_to_type(orb.boundary_order) == ref_type


@pytest.mark.parametrize(
'ref_count, apw_es, los',
[
([0, 0, 0, 0, 0], [], []),
([0, 0, 0, 0, 2], [[]], [None]),
([1, 0, 0, 0, 0], [[0.0]], []),
([2, 0, 0, 0, 0], 2 * [[0.0]], []),
([0, 1, 0, 0, 0], [2 * [0.0]], []),
([0, 0, 1, 0, 0], [3 * [0.0]], []),
([1, 1, 0, 0, 0], [[0.0], 2 * [0.0]], []),
([1, 1, 1, 0, 0], [[0.0], 2 * [0.0], 3 * [0.0]], []),
([0, 0, 0, 1, 0], [], ['lo']),
([0, 0, 0, 1, 0], [], ['LO']),
([0, 0, 0, 2, 0], [], ['lo', 'custom']),
],
)
def test_apw_l_channel(
ref_count: list[int], apw_es: list[list[float]], los: list[Optional[str]]
):
"""Test the L-channel APW structure."""
ref_keys = ('apw', 'lapw', 'slapw', 'lo', 'other')
l_channel = APWLChannel(
orbitals=list(
chain(
[APWOrbital(energy_parameter=apw_e) for apw_e in apw_es],
[APWLocalOrbital(type=lo) for lo in los],
)
)
)
assert l_channel._determine_apw() == dict(zip(ref_keys, ref_count))

0 comments on commit 7b298a2

Please sign in to comment.