Skip to content

Commit

Permalink
Adding unit testing for ModelSystem and AtomsState (#15)
Browse files Browse the repository at this point in the history
* Fix ruff version

* Added _check_quantum_numbers in OrbitalsState

Fix typing in atoms_state.py module

Added TODO in HubbardInteractions

Split method resolve_chemical_symbol_and_atomic_number in AtomsState

* Fix methods in model_system.py module

* Added testing for OrbitalsState, AtomsState, ModelSystem

Fix input to normalize to be EntryArchive

* Added testing for utils
  • Loading branch information
JosePizarro3 authored Apr 29, 2024
1 parent 53ba600 commit a7412b2
Show file tree
Hide file tree
Showing 15 changed files with 1,078 additions and 106 deletions.
1 change: 0 additions & 1 deletion .github/workflows/actions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,3 @@ jobs:
- uses: chartboost/ruff-action@v1
with:
args: "format . --check --verbose"
version: 0.1.8
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dev = [
'pytest==3.10.0',
'pytest-timeout==1.4.2',
'pytest-cov==2.7.1',
'ruff==0.1.8',
'ruff',
"structlog==22.3.0",
"lxml_html_clean>=0.1.0",
]
Expand Down
154 changes: 100 additions & 54 deletions src/nomad_simulations/atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

import numpy as np
import ase
import pint
from typing import Optional, Union, Dict, Any
from structlog.stdlib import BoundLogger

from nomad.units import ureg

from nomad.metainfo import Quantity, SubSection, MEnum, Section, Context
from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.metainfo.basesections import Entity
Expand All @@ -36,12 +36,9 @@ class OrbitalsState(Entity):
A base section used to define the orbital state of an atom.
"""

# TODO add check for `l_quantum_number` being only 0, 1, 2, 3
# TODO add check for `ml_quantum_number` being only -l, -l+1, ..., l-1, l
# TODO add check for `ms_quantum_number` being only -0.5 or 0.5
# TODO add check for `j_quantum_number` and `mj_quantum_number`

# TODO: add the relativistic kappa_quantum_number

n_quantum_number = Quantity(
type=np.int32,
description="""
Expand Down Expand Up @@ -163,6 +160,38 @@ def __init__(self, m_def: Section = None, m_context: Context = None, **kwargs):
'ms_numbers': dict((zip(('down', 'up'), (-0.5, 0.5)))),
}

def _check_quantum_numbers(self, logger: BoundLogger) -> bool:
"""
Checks the physicality of the quantum numbers.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(bool): True if the quantum numbers are physical, False otherwise.
"""
if self.n_quantum_number is not None and self.n_quantum_number < 1:
logger.error('The `n_quantum_number` must be greater than 0.')
return False
if self.l_quantum_number is not None and self.l_quantum_number < 1:
logger.error('The `l_quantum_number` must be greater than 0.')
return False
if self.ml_quantum_number is not None and (
self.ml_quantum_number < -self.l_quantum_number
or self.ml_quantum_number > self.l_quantum_number
):
logger.error(
'The `ml_quantum_number` must be between `-l_quantum_number` and `l_quantum_number`.'
)
return False
if self.ms_quantum_number is not None and self.ms_quantum_number not in [
-0.5,
0.5,
]:
logger.error('The `ms_quantum_number` must be -0.5 or 0.5.')
return False
return True

def resolve_number_and_symbol(
self, quantum_name: str, quantum_type: str, logger: BoundLogger
) -> Optional[Union[str, int]]:
Expand Down Expand Up @@ -202,9 +231,6 @@ def resolve_number_and_symbol(
self, f'{quantum_name}_quantum_{_countertype_map[quantum_type]}'
)
if other_quantity is None:
logger.debug(
f'Could not find the {quantum_name}_quantum_{quantum_type} countertype {_countertype_map[quantum_type]}.'
)
return None

# If the counterpart exists, then resolve the quantity from the orbitals_map
Expand Down Expand Up @@ -266,6 +292,11 @@ def resolve_degeneracy(self) -> Optional[int]:
def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# General checks for physical quantum numbers and symbols
if not self._check_quantum_numbers(logger):
logger.error('The quantum numbers are not physical.')
return

# Resolving the quantum numbers and symbols if not available
for quantum_name in ['l', 'ml', 'ms']:
for quantum_type in ['number', 'symbol']:
Expand All @@ -276,9 +307,8 @@ def normalize(self, archive, logger) -> None:
setattr(self, f'{quantum_name}_quantum_{quantum_type}', quantity)

# Resolve the degeneracy
self.degeneracy = (
self.resolve_degeneracy() if self.degeneracy is None else self.degeneracy
)
if self.degeneracy is None:
self.degeneracy = self.resolve_degeneracy()


class CoreHole(ArchiveSection):
Expand Down Expand Up @@ -314,49 +344,55 @@ class CoreHole(ArchiveSection):
""",
)

def resolve_occupation(self, logger: BoundLogger) -> None:
def resolve_occupation(self, logger: BoundLogger) -> Optional[np.float64]:
"""
Resolves the occupation of the orbital state. The occupation is resolved from the degeneracy
and the number of excited electrons.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[np.float64]): The occupation of the active orbital state.
"""
if self.orbital_ref is None or self.n_excited_electrons is None:
logger.warning(
'Cannot resolve occupation without `orbital_ref` or `n_excited_electrons`.'
)
return
return None
if self.orbital_ref.occupation is None:
degeneracy = self.orbital_ref.resolve_degeneracy()
if degeneracy is None:
logger.warning('Cannot resolve occupation without `degeneracy`.')
return
self.orbital_ref.occupation = degeneracy - self.n_excited_electrons
return None
return degeneracy - self.n_excited_electrons
return None

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Check if n_excited_electrons is between 0 and 1
if 0.0 <= self.n_excited_electrons <= 1.0:
if self.n_excited_electrons < 0 or self.n_excited_electrons > 1:
logger.error('Number of excited electrons must be between 0 and 1.')

# If dscf_state is 'initial', then n_excited_electrons is set to 0
if self.dscf_state == 'initial':
self.n_excited_electrons = None
self.degeneracy = 1
return

# Resolve the occupation of the active orbital state
if self.orbital_ref is not None and self.n_excited_electrons:
if self.orbital_ref is not None:
# If dscf_state is 'initial', then n_excited_electrons is set to 0
if self.dscf_state == 'initial':
self.n_excited_electrons = None
self.orbital_ref.degeneracy = 1
if self.orbital_ref.occupation is None:
self.resolve_occupation(logger)
self.orbital_ref.occupation = self.resolve_occupation(logger)


class HubbardInteractions(ArchiveSection):
"""
A base section to define the Hubbard interactions of the system.
"""

# TODO (@JosePizarro3 note): we need to have checks for when a `ModelSystem` is spin rotational invariant (then we only need to pass `u_interaction` and `j_hunds_coupling` and resolve the other quantities)

n_orbitals = Quantity(
type=np.int32,
description="""
Expand Down Expand Up @@ -458,57 +494,49 @@ class HubbardInteractions(ArchiveSection):
def resolve_u_interactions(self, logger: BoundLogger) -> Optional[tuple]:
"""
Resolves the Hubbard interactions (u_interaction, u_interorbital_interaction, j_hunds_coupling)
from the Slater integrals (F0, F2, F4).
from the Slater integrals (F0, F2, F4) in the units defined for the Quantity.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[tuple]): The Hubbard interactions (u_interaction, u_interorbital_interaction, j_hunds_coupling).
"""
if self.slater_integrals is None or len(self.slater_integrals) == 3:
if self.slater_integrals is None or len(self.slater_integrals) != 3:
logger.warning(
'Could not find `slater_integrals` or the length is not three.'
)
return None
return None, None, None
f0 = self.slater_integrals[0]
f2 = self.slater_integrals[1]
f4 = self.slater_integrals[2]
u_interaction = (
((2.0 / 7.0) ** 2)
* (f0 + 5.0 * f2 + 9.0 * f4)
/ (4.0 * np.pi)
* ureg('joule')
)
u_interaction = ((2.0 / 7.0) ** 2) * (f0 + 5.0 * f2 + 9.0 * f4) / (4.0 * np.pi)
u_interorbital_interaction = (
((2.0 / 7.0) ** 2)
* (f0 - 5.0 * f2 + 3.0 * f4 / 2.0)
/ (4.0 * np.pi)
* ureg('joule')
((2.0 / 7.0) ** 2) * (f0 - 5.0 * f2 + 3.0 * f4 / 2.0) / (4.0 * np.pi)
)
j_hunds_coupling = (
((2.0 / 7.0) ** 2)
* (5.0 * f2 + 15.0 * f4 / 4.0)
/ (4.0 * np.pi)
* ureg('joule')
((2.0 / 7.0) ** 2) * (5.0 * f2 + 15.0 * f4 / 4.0) / (4.0 * np.pi)
)
return u_interaction, u_interorbital_interaction, j_hunds_coupling

def resolve_u_effective(self, logger: BoundLogger) -> Optional[np.float64]:
def resolve_u_effective(self, logger: BoundLogger) -> Optional[pint.Quantity]:
"""
Resolves the effective U parameter (u_interaction - j_local_exchange_interaction).
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[np.float64]): The effective U parameter.
(Optional[pint.Quantity]): The effective U parameter.
"""
if self.u_interaction is None or self.j_local_exchange_interaction is None:
logger.warning(
'Could not find `HubbardInteractions.u_interaction` or `HubbardInteractions.j_local_exchange_interaction`.'
)
if self.u_interaction is None:
logger.warning('Could not find `HubbardInteractions.u_interaction`.')
return None
if self.u_interaction.magnitude < 0.0:
logger.error('The `HubbardInteractions.u_interaction` must be positive.')
return None
if self.j_local_exchange_interaction is None:
self.j_local_exchange_interaction = 0.0 * ureg.eV
return self.u_interaction - self.j_local_exchange_interaction

def normalize(self, archive, logger) -> None:
Expand Down Expand Up @@ -581,31 +609,49 @@ class AtomsState(Entity):
sub_section=HubbardInteractions.m_def, repeats=False
)

def resolve_chemical_symbol_and_number(self, logger: BoundLogger) -> None:
def resolve_chemical_symbol(self, logger: BoundLogger) -> Optional[str]:
"""
Resolves the chemical symbol from the atomic number and viceversa.
Resolves the `chemical_symbol` from the `atomic_number`.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[str]): The resolved `chemical_symbol`.
"""
f = lambda x: tuple(map(bool, x))
if f((self.chemical_symbol, self.atomic_number)) == f((None, not None)):
if self.atomic_number is not None:
try:
self.chemical_symbol = ase.data.chemical_symbols[self.atomic_number]
return ase.data.chemical_symbols[self.atomic_number]
except IndexError:
logger.error(
'The `AtomsState.atomic_number` is out of range of the periodic table.'
)
elif f((self.chemical_symbol, self.atomic_number)) == f((not None, None)):
return None

def resolve_atomic_number(self, logger: BoundLogger) -> Optional[int]:
"""
Resolves the `atomic_number` from the `chemical_symbol`.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[int]): The resolved `atomic_number`.
"""
if self.chemical_symbol is not None:
try:
self.atomic_number = ase.data.atomic_numbers[self.chemical_symbol]
return ase.data.atomic_numbers[self.chemical_symbol]
except IndexError:
logger.error(
'The `AtomsState.chemical_symbol` is not recognized in the periodic table.'
)
return None

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Get chemical_symbol from atomic_number and viceversa
self.resolve_chemical_symbol_and_number(logger)
if self.chemical_symbol is None:
self.chemical_symbol = self.resolve_chemical_symbol(logger)
if self.atomic_number is None:
self.atomic_number = self.resolve_atomic_number(logger)
Loading

0 comments on commit a7412b2

Please sign in to comment.