Skip to content

Commit

Permalink
Resolve section class from archive
Browse files Browse the repository at this point in the history
  • Loading branch information
ladinesa committed Dec 30, 2023
1 parent 995a7f6 commit 2ee441d
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions systemnormalizer/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, Dict
from nptyping import NDArray
import ase
from ase import Atoms
import numpy as np
import json
import re
Expand All @@ -30,10 +29,8 @@
from nomad.atomutils import Formula
from nomad.units import ureg
from nomad import utils, config
from runschema.system import (
Atoms, Symmetry, SpringerMaterial, Prototype)

from .normalizer import SystemBasedNormalizer
from nomad.normalizing.normalizer import SystemBasedNormalizer

# use a regular expression to check atom labels; expression is build from list of
# all labels sorted desc to find Br and not B when searching for Br.
Expand Down Expand Up @@ -100,8 +97,9 @@ def normalize_system(self, system, is_representative) -> bool:
self.logger.error('section_run is not present.')
return False

atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls
if system.atoms is None:
system.m_create(Atoms)
system.atoms = atoms_cls()

def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any = None) -> Any:
try:
Expand All @@ -121,11 +119,11 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any
system.is_representative = is_representative

# analyze atoms labels
atom_labels = get_value(Atoms.labels, numpy=False, source=system.atoms)
atom_labels = get_value(atoms_cls.labels, numpy=False, source=system.atoms)
if atom_labels is not None:
atom_labels = normalized_atom_labels(atom_labels)

atom_species = get_value(Atoms.species, numpy=False, source=system.atoms)
atom_species = get_value(atoms_cls.species, numpy=False, source=system.atoms)
if atom_labels is None and atom_species is None:
self.logger.warn('system has neither atom species nor labels')
return False
Expand Down Expand Up @@ -167,7 +165,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any
system.atoms.species = atom_species

# periodic boundary conditions
pbc = get_value(Atoms.periodic, numpy=False, source=system.atoms)
pbc = get_value(atoms_cls.periodic, numpy=False, source=system.atoms)
if pbc is None:
pbc = [False, False, False]
self.logger.warning('missing configuration_periodic_dimensions')
Expand All @@ -189,7 +187,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any
self.logger.error('could not extract chemical formula', exc_info=e, error=str(e))

# positions
atom_positions = get_value(Atoms.positions, numpy=True, source=system.atoms)
atom_positions = get_value(atoms_cls.positions, numpy=True, source=system.atoms)
if atom_positions is None or len(atom_positions) == 0:
self.logger.warning('no atom positions, skip further system analysis')
return False
Expand All @@ -206,7 +204,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any
return False

# lattice vectors
lattice_vectors = get_value(Atoms.lattice_vectors, numpy=True, source=system.atoms)
lattice_vectors = get_value(atoms_cls.lattice_vectors, numpy=True, source=system.atoms)
if lattice_vectors is None:
if any(pbc):
self.logger.error('no lattice vectors but periodicity', pbc=pbc)
Expand All @@ -220,7 +218,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any

# reciprocal lattice vectors
lattice_vectors_reciprocal = get_value(
Atoms.lattice_vectors_reciprocal, numpy=True, source=system.atoms)
atoms_cls.lattice_vectors_reciprocal, numpy=True, source=system.atoms)
if lattice_vectors_reciprocal is None and lattice_vectors is not None:
system.atoms.lattice_vectors_reciprocal = 2 * np.pi * atomutils.reciprocal_cell(lattice_vectors.magnitude) # there is also a get_reciprocal_cell method in ase

Expand Down Expand Up @@ -252,7 +250,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any

return True

def system_type_analysis(self, atoms: Atoms) -> None:
def system_type_analysis(self, atoms: ase.Atoms) -> None:
'''
Determine the system type with MatID. Write the system type to the
entry_archive.
Expand Down Expand Up @@ -347,8 +345,9 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None:

# Write data extracted from MatID's symmetry analysis to the
# representative section_system.

sec_symmetry = system.m_create(Symmetry)
# symmetry_cls = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls
sec_symmetry = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls()
system.symmetry.append(sec_symmetry)
sec_symmetry.m_cache["symmetry_analyzer"] = symm

sec_symmetry.symmetry_method = 'MatID (spg)'
Expand All @@ -362,21 +361,25 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None:
sec_symmetry.origin_shift = origin_shift
sec_symmetry.transformation_matrix = transform

sec_std = sec_symmetry.m_create(Atoms, Symmetry.system_std)
atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls
sec_std = atoms_cls()
sec_symmetry.system_std.append(sec_std)
sec_std.lattice_vectors = conv_cell * ureg.angstrom
sec_std.positions = conv_pos
sec_std.atomic_numbers = conv_num
sec_std.wyckoff_letters = conv_wyckoff
sec_std.equivalent_atoms = conv_equivalent_atoms

sec_prim = sec_symmetry.m_create(Atoms, Symmetry.system_primitive)
sec_prim = atoms_cls()
sec_symmetry.system_primitive.append(sec_prim)
sec_prim.lattice_vectors = prim_cell * ureg.angstrom
sec_prim.positions = prim_pos
sec_prim.atomic_numbers = prim_num
sec_prim.wyckoff_letters = prim_wyckoff
sec_prim.equivalent_atoms = prim_equivalent_atoms

sec_orig = sec_symmetry.m_create(Atoms, Symmetry.system_original)
sec_orig = atoms_cls()
sec_symmetry.system_original.append(sec_orig)
sec_orig.wyckoff_letters = orig_wyckoff
sec_orig.equivalent_atoms = orig_equivalent_atoms

Expand All @@ -389,7 +392,8 @@ def springer_classification(self, atoms, space_group_number):
idx = self.section_run.m_cache["representative_system_idx"]

for material in springer_data.values():
sec_springer_mat = self.section_run.system[idx].m_create(SpringerMaterial)
sec_springer_mat = self.section_run.system[idx].m_def.all_sub_sections['springer_material'].sub_section.section_cls()
self.section_run.system[idx].springer_material.append(sec_springer_mat)

sec_springer_mat.id = material['spr_id']
sec_springer_mat.alphabetical_formula = material['spr_aformula']
Expand Down Expand Up @@ -443,7 +447,8 @@ def prototypes(self, system, atom_species: NDArray, wyckoffs: NDArray, spg_numbe
protoDict.get("Pearsons Symbol", "-")
)
idx = self.section_run.m_cache["representative_system_idx"]
sec_prototype = self.section_run.system[idx].m_create(Prototype)
sec_prototype = self.section_run.system[idx].m_def.all_sub_sections['prototype'].sub_section.section_cls()
self.section_run.system[idx].prototype.append(sec_prototype)
sec_prototype.label = prototype_label
sec_prototype.aflow_id = aflow_prototype_id
sec_prototype.aflow_url = aflow_prototype_url
Expand Down

0 comments on commit 2ee441d

Please sign in to comment.