Skip to content

Commit

Permalink
- Improve on APW structure (may need more test cases)
Browse files Browse the repository at this point in the history
- TODO: test `_determine_apw`
- TODO: test QuickStep
  • Loading branch information
ndaelman committed Aug 19, 2024
1 parent 862f5ee commit 713b904
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 22 deletions.
32 changes: 16 additions & 16 deletions src/nomad_simulations/schema_packages/basis_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from nomad_simulations.schema_packages.properties.energies import EnergyContribution


class BasisSet(ArchiveSection):
class BasisSet(NumericalSettings):
"""A type section denoting a basis set component of a simulation.
Should be used as a base section for more specialized sections.
Allows for denoting the basis set's _scope_, i.e. to which entity it applies,
Expand All @@ -32,6 +32,13 @@ class BasisSet(ArchiveSection):
- atom-centered basis sets, e.g. Gaussian-type basis sets, Slater-type orbitals, muffin-tin orbitals
"""

name = Quantity(
type=str,
description="""
Name of the basis set component.
""",
)

species_scope = Quantity(
type=AtomsState,
shape=['*'],
Expand All @@ -55,8 +62,9 @@ class BasisSet(ArchiveSection):

# ? band_scope or orbital_scope: valence vs core

def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs):
super().__init__(m_def, m_context, **kwargs)
def normalize(self, archive, logger):
super().normalize(archive, logger)
self.name = self.m_def.name


class PlaneWaveBasisSet(BasisSet, Mesh):
Expand Down Expand Up @@ -286,7 +294,7 @@ class APWLChannel(BasisSet):
description="""
Angular momentum quantum number of the local orbital.
""",
) # TODO: add `l` as a quantity
)

n_wavefunctions = Quantity(
type=np.int32,
Expand Down Expand Up @@ -322,8 +330,7 @@ def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
return count

def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:
super().normalize(archive, logger)
# self.name = self.m_def.name
super(BasisSet).normalize(archive, logger)
self.n_wavefunctions = len(self.orbitals) * (2 * self.name + 1)


Expand Down Expand Up @@ -366,10 +373,6 @@ def _determine_apw(self, logger: BoundLogger) -> dict[str, int]:
count.update(channel._determine_apw(logger))
return count

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
# self.type = 'spherical'


class BasisSetContainer(NumericalSettings):
"""
Expand Down Expand Up @@ -447,7 +450,7 @@ def normalize(self, archive: EntryArchive, logger: BoundLogger) -> None:

def generate_apw(
species: dict[str, dict[str, Any]], cutoff: Optional[float] = None
) -> BasisSetContainer:
) -> BasisSetContainer: # TODO: extend to cover all parsing use cases (maybe split up?)
"""
Generate a mock APW basis set with the following structure:
.
Expand All @@ -473,7 +476,7 @@ def generate_apw(
pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff)
basis_set_components.append(pw)

for sp_name, sp in species.items():
for sp_ref, sp in species.items():
sp['r'] = sp.get('r', None)
sp['l_max'] = sp.get('l_max', 0)
sp['orb_type'] = sp.get('orb_type', [])
Expand All @@ -482,9 +485,7 @@ def generate_apw(
basis_set_components.extend(
[
MuffinTinRegion(
species_scope=AtomsState(
chemical_symbol=sp_name
), # TODO: extend to search through a model_system
species_scope=[sp_ref],
radius=sp['r'],
l_max=sp['l_max'],
l_channels=[
Expand All @@ -500,7 +501,6 @@ def generate_apw(
for l in range(sp['l_max'] + 1)
],
)

]
)

Expand Down
17 changes: 13 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,29 @@ def k_space_simulation() -> Simulation:


refs_apw = [
{},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer',
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer',
'basis_set_components': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet',
'cutoff_energy': 500.0,
},
]
],
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer',
'basis_set_components': [
{
'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet',
'cutoff_energy': 500.0,
},
{
'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion',
'radius': 1.823,
'species_scope': ['/data/model_system/0/cell/0/atoms_state/0'],
'radius': 1.0,
'l_max': 2,
'l_channels': [
{
Expand Down Expand Up @@ -450,7 +459,7 @@ def k_space_simulation() -> Simulation:
},
],
},
]
],
},
{
'basis_set_components': [
Expand Down
34 changes: 32 additions & 2 deletions tests/test_basis_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import Any, Optional

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.general import Simulation
from nomad_simulations.schema_packages.model_method import ModelMethod
from nomad_simulations.schema_packages.model_system import AtomicCell, ModelSystem
import pytest

from nomad.datamodel.datamodel import EntryArchive
from . import logger
from nomad.units import ureg
import numpy as np
Expand Down Expand Up @@ -56,13 +62,37 @@ def test_cutoff_failure():
[
(0, {}, None),
(1, {}, 500.0),
(2, {'H': {'r': 1, 'l_max': 2, 'orb_type': ['apw']}}, 500.0),
(
2,
{
'/data/model_system/0/cell/0/atoms_state/0': {
'r': 1,
'l_max': 2,
'orb_type': ['apw'],
}
},
500.0,
),
],
)
def test_full_apw(
ref_index: int, species_def: dict[str, dict[str, Any]], cutoff: Optional[float]
):
"""Test the composite structure of APW basis sets."""
entry = EntryArchive(
data=Simulation(
model_system=[
ModelSystem(
cell=[AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')])]
)
],
model_method=[ModelMethod(numerical_settings=[])],
)
)

numerical_settings = entry.data.model_method[0].numerical_settings
numerical_settings.append(generate_apw(species_def, cutoff=cutoff))

assert (
generate_apw(species_def, cutoff=cutoff).m_to_dict() == refs_apw[ref_index]
numerical_settings[0].m_to_dict() == refs_apw[ref_index]
) # TODO: add normalization?

0 comments on commit 713b904

Please sign in to comment.