From 9b61bdf3e97f8b8401101bd6bfaaae73885d2411 Mon Sep 17 00:00:00 2001 From: ndaelman Date: Fri, 16 Aug 2024 21:18:47 +0200 Subject: [PATCH] Add test template for the APW structure --- .../schema_packages/basis_set.py | 5 +- tests/conftest.py | 68 +++++++++++++++++++ tests/test_basis_set.py | 55 ++++++++++----- 3 files changed, 107 insertions(+), 21 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 225fb8f1..6d6e8b86 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -56,8 +56,6 @@ class BasisSet(ArchiveSection): def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): super().__init__(m_def, m_context, **kwargs) - # Set the name of the section - self.name = self.m_def.name class PlaneWaveBasisSet(BasisSet, Mesh): @@ -235,6 +233,7 @@ class APWOrbital(APWBaseOrbital): """, ) + class APWLocalOrbital(APWBaseOrbital): """ Implementation of `APWWavefunction` capturing a local orbital extending a foundational APW basis set. @@ -286,7 +285,7 @@ class APWLChannel(BasisSet): description=""" Angular momentum quantum number of the local orbital. """, - ) + ) # TODO: add `l` as a quantity n_wavefunctions = Quantity( type=np.int32, diff --git a/tests/conftest.py b/tests/conftest.py index c4a5a798..9cc1d4f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -402,3 +402,71 @@ def k_line_path() -> KLinePathSettings: @pytest.fixture(scope='session') def k_space_simulation() -> Simulation: return generate_k_space_simulation() + + +apw = { + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', + 'cutoff_energy': 500.0, + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion', + 'radius': 1.823, + 'l_max': 2, + 'l_channels': [ + { + 'name': 0, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + { + 'name': 1, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + { + 'name': 2, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'apw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'type': 'lapw', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWLocalOrbital', + 'type': 'lo', + }, + ], + }, + ], + }, + ] +} diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index 15c7b14a..01bd1cdc 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -1,8 +1,11 @@ +import itertools from . import logger from nomad.units import ureg import numpy as np import pytest -from typing import Optional +from typing import Optional, Any + +from tests.conftest import apw from nomad_simulations.schema_packages.basis_set import ( APWBaseOrbital, @@ -47,19 +50,17 @@ def test_cutoff_failure(): assert pw.cutoff_fractional == 1 -@pytest.mark.skip(reason="This function is not meant to be tested directly") +@pytest.mark.skip(reason='This function is not meant to be tested directly') def generate_apw( - species: dict[str, int | APWBaseOrbital], - cutoff: Optional[float] = None + species: dict[str, dict[str, Any]], cutoff: Optional[float] = None ) -> BasisSetContainer: """ Generate a mock APW basis set with the following structure: . - ├── plane-wave basis set - └── muffin-tin regions - └── l-channels - ├── (orbitals) - │ └── wavefunctions + ├── 1 x plane-wave basis set + └── n x muffin-tin regions + └── l_max x l-channels + ├── orbitals └── local orbitals """ basis_set_components: list[BasisSet] = [] @@ -67,21 +68,39 @@ def generate_apw( pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff) basis_set_components.append(pw) - mts: list[MuffinTinRegion] = [] - for sp in species: + for sp_name, sp in species.items(): l_max = sp['l_max'] mt = MuffinTinRegion( radius=sp['r'], l_max=l_max, l_channels=[ APWLChannel( - l=l, - orbitals=[APWOrbital(type=orb) for orb in sp['orb_type']] +\ - [APWLocalOrbital(type=lo) for lo in sp['lo_type']], - ) for l in range(l_max) - ] + name=l, + orbitals=list( + itertools.chain( + (APWOrbital(type=orb) for orb in sp['orb_type']), + (APWLocalOrbital(type=lo) for lo in sp['lo_type']), + ) + ), + ) + for l in range(l_max + 1) + ], ) - mts.append(mt) - basis_set_components.append(mts) + basis_set_components.append(mt) return BasisSetContainer(basis_set_components=basis_set_components) + + +def test_full_apw(): + ref_apw = generate_apw( + { + 'A': { + 'r': 1.823, + 'l_max': 2, + 'orb_type': ['apw', 'lapw'], + 'lo_type': ['lo'], + } + }, + cutoff=500, + ) + assert ref_apw.m_to_dict() == apw