From 2ee441da5c4bff21f3492c0d96dc4fb203cbc629 Mon Sep 17 00:00:00 2001 From: Alvin Noe Ladines Date: Sat, 30 Dec 2023 23:35:58 +0100 Subject: [PATCH] Resolve section class from archive --- systemnormalizer/normalizer.py | 43 +++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/systemnormalizer/normalizer.py b/systemnormalizer/normalizer.py index 6c7a437..ffeb73d 100644 --- a/systemnormalizer/normalizer.py +++ b/systemnormalizer/normalizer.py @@ -19,7 +19,6 @@ from typing import Any, Dict from nptyping import NDArray import ase -from ase import Atoms import numpy as np import json import re @@ -30,10 +29,8 @@ from nomad.atomutils import Formula from nomad.units import ureg from nomad import utils, config -from runschema.system import ( - Atoms, Symmetry, SpringerMaterial, Prototype) -from .normalizer import SystemBasedNormalizer +from nomad.normalizing.normalizer import SystemBasedNormalizer # use a regular expression to check atom labels; expression is build from list of # all labels sorted desc to find Br and not B when searching for Br. @@ -100,8 +97,9 @@ def normalize_system(self, system, is_representative) -> bool: self.logger.error('section_run is not present.') return False + atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls if system.atoms is None: - system.m_create(Atoms) + system.atoms = atoms_cls() def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any = None) -> Any: try: @@ -121,11 +119,11 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any system.is_representative = is_representative # analyze atoms labels - atom_labels = get_value(Atoms.labels, numpy=False, source=system.atoms) + atom_labels = get_value(atoms_cls.labels, numpy=False, source=system.atoms) if atom_labels is not None: atom_labels = normalized_atom_labels(atom_labels) - atom_species = get_value(Atoms.species, numpy=False, source=system.atoms) + atom_species = get_value(atoms_cls.species, numpy=False, source=system.atoms) if atom_labels is None and atom_species is None: self.logger.warn('system has neither atom species nor labels') return False @@ -167,7 +165,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any system.atoms.species = atom_species # periodic boundary conditions - pbc = get_value(Atoms.periodic, numpy=False, source=system.atoms) + pbc = get_value(atoms_cls.periodic, numpy=False, source=system.atoms) if pbc is None: pbc = [False, False, False] self.logger.warning('missing configuration_periodic_dimensions') @@ -189,7 +187,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any self.logger.error('could not extract chemical formula', exc_info=e, error=str(e)) # positions - atom_positions = get_value(Atoms.positions, numpy=True, source=system.atoms) + atom_positions = get_value(atoms_cls.positions, numpy=True, source=system.atoms) if atom_positions is None or len(atom_positions) == 0: self.logger.warning('no atom positions, skip further system analysis') return False @@ -206,7 +204,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any return False # lattice vectors - lattice_vectors = get_value(Atoms.lattice_vectors, numpy=True, source=system.atoms) + lattice_vectors = get_value(atoms_cls.lattice_vectors, numpy=True, source=system.atoms) if lattice_vectors is None: if any(pbc): self.logger.error('no lattice vectors but periodicity', pbc=pbc) @@ -220,7 +218,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any # reciprocal lattice vectors lattice_vectors_reciprocal = get_value( - Atoms.lattice_vectors_reciprocal, numpy=True, source=system.atoms) + atoms_cls.lattice_vectors_reciprocal, numpy=True, source=system.atoms) if lattice_vectors_reciprocal is None and lattice_vectors is not None: system.atoms.lattice_vectors_reciprocal = 2 * np.pi * atomutils.reciprocal_cell(lattice_vectors.magnitude) # there is also a get_reciprocal_cell method in ase @@ -252,7 +250,7 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any return True - def system_type_analysis(self, atoms: Atoms) -> None: + def system_type_analysis(self, atoms: ase.Atoms) -> None: ''' Determine the system type with MatID. Write the system type to the entry_archive. @@ -347,8 +345,9 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: # Write data extracted from MatID's symmetry analysis to the # representative section_system. - - sec_symmetry = system.m_create(Symmetry) + # symmetry_cls = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls + sec_symmetry = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls() + system.symmetry.append(sec_symmetry) sec_symmetry.m_cache["symmetry_analyzer"] = symm sec_symmetry.symmetry_method = 'MatID (spg)' @@ -362,21 +361,25 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: sec_symmetry.origin_shift = origin_shift sec_symmetry.transformation_matrix = transform - sec_std = sec_symmetry.m_create(Atoms, Symmetry.system_std) + atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls + sec_std = atoms_cls() + sec_symmetry.system_std.append(sec_std) sec_std.lattice_vectors = conv_cell * ureg.angstrom sec_std.positions = conv_pos sec_std.atomic_numbers = conv_num sec_std.wyckoff_letters = conv_wyckoff sec_std.equivalent_atoms = conv_equivalent_atoms - sec_prim = sec_symmetry.m_create(Atoms, Symmetry.system_primitive) + sec_prim = atoms_cls() + sec_symmetry.system_primitive.append(sec_prim) sec_prim.lattice_vectors = prim_cell * ureg.angstrom sec_prim.positions = prim_pos sec_prim.atomic_numbers = prim_num sec_prim.wyckoff_letters = prim_wyckoff sec_prim.equivalent_atoms = prim_equivalent_atoms - sec_orig = sec_symmetry.m_create(Atoms, Symmetry.system_original) + sec_orig = atoms_cls() + sec_symmetry.system_original.append(sec_orig) sec_orig.wyckoff_letters = orig_wyckoff sec_orig.equivalent_atoms = orig_equivalent_atoms @@ -389,7 +392,8 @@ def springer_classification(self, atoms, space_group_number): idx = self.section_run.m_cache["representative_system_idx"] for material in springer_data.values(): - sec_springer_mat = self.section_run.system[idx].m_create(SpringerMaterial) + sec_springer_mat = self.section_run.system[idx].m_def.all_sub_sections['springer_material'].sub_section.section_cls() + self.section_run.system[idx].springer_material.append(sec_springer_mat) sec_springer_mat.id = material['spr_id'] sec_springer_mat.alphabetical_formula = material['spr_aformula'] @@ -443,7 +447,8 @@ def prototypes(self, system, atom_species: NDArray, wyckoffs: NDArray, spg_numbe protoDict.get("Pearsons Symbol", "-") ) idx = self.section_run.m_cache["representative_system_idx"] - sec_prototype = self.section_run.system[idx].m_create(Prototype) + sec_prototype = self.section_run.system[idx].m_def.all_sub_sections['prototype'].sub_section.section_cls() + self.section_run.system[idx].prototype.append(sec_prototype) sec_prototype.label = prototype_label sec_prototype.aflow_id = aflow_prototype_id sec_prototype.aflow_url = aflow_prototype_url