Skip to content

Commit

Permalink
Improved utils function
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Feb 9, 2024
1 parent 3662c20 commit 0bbc90b
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
39 changes: 24 additions & 15 deletions simulationdataschema/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from nomad.datamodel.metainfo.basesections import System, GeometricSpace
from nomad.datamodel.metainfo.annotations import ELNAnnotation

from .utils import get_sub_section_from_section_parent
from .utils import get_sibling_section


class AtomicCell(GeometricSpace):
Expand Down Expand Up @@ -109,6 +109,16 @@ class AtomicCell(GeometricSpace):
""",
)

velocities = Quantity(
type=np.float64,
shape=["*", 3],
unit="meter / second",
description="""
Velocities of the atoms. It is the change in cartesian coordinates of the atom position
with time.
""",
) # ? what about forces, stress

lattice_vectors = Quantity(
type=np.float64,
shape=[3, 3],
Expand Down Expand Up @@ -139,16 +149,6 @@ class AtomicCell(GeometricSpace):
""",
)

velocities = Quantity(
type=np.float64,
shape=["*", 3],
unit="meter / second",
description="""
Velocities of the atoms. It is the change in cartesian coordinates of the atom position
with time.
""",
)

supercell_matrix = Quantity(
type=np.int32,
shape=[3, 3],
Expand Down Expand Up @@ -185,6 +185,9 @@ def to_ase_atoms(self, logger: BoundLogger) -> Union[ase.Atoms, None]:
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
Args:
logger (BoundLogger): The logger to log messages.
Returns:
Union[ase.Atoms, None]: The ASE Atoms object with the basic information from the `AtomicCell`.
"""
Expand All @@ -199,7 +202,7 @@ def to_ase_atoms(self, logger: BoundLogger) -> Union[ase.Atoms, None]:
self.periodic_boundary_conditions = [False, False, False]
ase_atoms.set_pbc(self.periodic_boundary_conditions)

# Positions
# Positions (ensure they are parsed)
if self.positions is not None:
if len(self.positions) != len(self.labels):
logger.error(
Expand Down Expand Up @@ -375,6 +378,7 @@ def resolve_analyzed_atomic_cell(
Args:
symmetry_analyzer (SymmetryAnalyzer): The `SymmetryAnalyzer` object used to resolve.
cell_type (str): The type of cell to resolve, either 'primitive' or 'conventional'.
logger (BoundLogger): The logger to log messages.
Returns:
Union[AtomicCell, None]: The resolved `AtomicCell` section or None if the cell_type
Expand Down Expand Up @@ -418,6 +422,7 @@ def resolve_bulk_symmetry(
Args:
original_atomic_cell (AtomicCell): The `AtomicCell` section that the symmetry
uses to in MatID.SymmetryAnalyzer().
logger (BoundLogger): The logger to log messages.
Returns:
primitive_atomic_cell (AtomicCell): The primitive `AtomicCell` section.
conventional_atomic_cell (AtomicCell): The standarized `AtomicCell` section.
Expand Down Expand Up @@ -454,7 +459,7 @@ def resolve_bulk_symmetry(
original_wyckoff = symmetry_analyzer.get_wyckoff_letters_original()
original_equivalent_atoms = symmetry_analyzer.get_equivalent_atoms_original()
original_atomic_cell.wyckoff_letters = original_wyckoff
original_atomic_cell.equivalent_particles = original_equivalent_atoms
original_atomic_cell.equivalent_atoms = original_equivalent_atoms

# Populating the primitive AtomicCell information
primitive_atomic_cell = self.resolve_analyzed_atomic_cell(
Expand Down Expand Up @@ -498,7 +503,9 @@ def resolve_bulk_symmetry(
return primitive_atomic_cell, conventional_atomic_cell

def normalize(self, archive, logger):
atomic_cell = get_sub_section_from_section_parent(self, "atomic_cell", logger)
atomic_cell = get_sibling_section(
section=self, sibling_section_name="atomic_cell", logger=logger
)
if self.m_parent.type == "bulk":
# Adding the newly calculated primitive and conventional cells to the ModelSystem
(
Expand Down Expand Up @@ -585,7 +592,9 @@ def resolve_chemical_formulas(self, formula: Formula):
self.anonymous = formula.format("anonymous")

def normalize(self, archive, logger):
atomic_cell = get_sub_section_from_section_parent(self, "atomic_cell", logger)
atomic_cell = get_sibling_section(
section=self, sibling_section_name="atomic_cell", logger=logger
)
ase_atoms = atomic_cell.to_ase_atoms(logger)
formula = None
try:
Expand Down
39 changes: 24 additions & 15 deletions simulationdataschema/model_system_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from nomad.datamodel.metainfo.basesections import System, GeometricSpace
from nomad.datamodel.metainfo.annotations import ELNAnnotation

from .utils import get_sub_section_from_section_parent
from .utils import get_sibling_section


class AtomicCell(GeometricSpace):
Expand Down Expand Up @@ -109,6 +109,16 @@ class AtomicCell(GeometricSpace):
""",
)

velocities = Quantity(
type=np.float64,
shape=["*", 3],
unit="meter / second",
description="""
Velocities of the atoms. It is the change in cartesian coordinates of the atom position
with time.
""",
) # ? what about forces, stress

lattice_vectors = Quantity(
type=np.float64,
shape=[3, 3],
Expand Down Expand Up @@ -139,16 +149,6 @@ class AtomicCell(GeometricSpace):
""",
)

velocities = Quantity(
type=np.float64,
shape=["*", 3],
unit="meter / second",
description="""
Velocities of the atoms. It is the change in cartesian coordinates of the atom position
with time.
""",
)

supercell_matrix = Quantity(
type=np.int32,
shape=[3, 3],
Expand Down Expand Up @@ -185,6 +185,9 @@ def to_ase_atoms(self, logger: BoundLogger) -> Union[ase.Atoms, None]:
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
Args:
logger (BoundLogger): The logger to log messages.
Returns:
Union[ase.Atoms, None]: The ASE Atoms object with the basic information from the `AtomicCell`.
"""
Expand All @@ -199,7 +202,7 @@ def to_ase_atoms(self, logger: BoundLogger) -> Union[ase.Atoms, None]:
self.periodic_boundary_conditions = [False, False, False]
ase_atoms.set_pbc(self.periodic_boundary_conditions)

# Positions
# Positions (ensure they are parsed)
if self.positions is not None:
if len(self.positions) != len(self.labels):
logger.error(
Expand Down Expand Up @@ -375,6 +378,7 @@ def resolve_analyzed_atomic_cell(
Args:
symmetry_analyzer (SymmetryAnalyzer): The `SymmetryAnalyzer` object used to resolve.
cell_type (str): The type of cell to resolve, either 'primitive' or 'conventional'.
logger (BoundLogger): The logger to log messages.
Returns:
Union[AtomicCell, None]: The resolved `AtomicCell` section or None if the cell_type
Expand Down Expand Up @@ -418,6 +422,7 @@ def resolve_bulk_symmetry(
Args:
original_atomic_cell (AtomicCell): The `AtomicCell` section that the symmetry
uses to in MatID.SymmetryAnalyzer().
logger (BoundLogger): The logger to log messages.
Returns:
primitive_atomic_cell (AtomicCell): The primitive `AtomicCell` section.
conventional_atomic_cell (AtomicCell): The standarized `AtomicCell` section.
Expand Down Expand Up @@ -454,7 +459,7 @@ def resolve_bulk_symmetry(
original_wyckoff = symmetry_analyzer.get_wyckoff_letters_original()
original_equivalent_atoms = symmetry_analyzer.get_equivalent_atoms_original()
original_atomic_cell.wyckoff_letters = original_wyckoff
original_atomic_cell.equivalent_particles = original_equivalent_atoms
original_atomic_cell.equivalent_atoms = original_equivalent_atoms

# Populating the primitive AtomicCell information
primitive_atomic_cell = self.resolve_analyzed_atomic_cell(
Expand Down Expand Up @@ -498,7 +503,9 @@ def resolve_bulk_symmetry(
return primitive_atomic_cell, conventional_atomic_cell

def normalize(self, archive, logger):
atomic_cell = get_sub_section_from_section_parent(self, "atomic_cell", logger)
atomic_cell = get_sibling_section(
section=self, sibling_section_name="atomic_cell", logger=logger
)
if self.m_parent.type == "bulk":
# Adding the newly calculated primitive and conventional cells to the ModelSystem
(
Expand Down Expand Up @@ -585,7 +592,9 @@ def resolve_chemical_formulas(self, formula: Formula):
self.anonymous = formula.format("anonymous")

def normalize(self, archive, logger):
atomic_cell = get_sub_section_from_section_parent(self, "atomic_cell", logger)
atomic_cell = get_sibling_section(
section=self, sibling_section_name="atomic_cell", logger=logger
)
ase_atoms = atomic_cell.to_ase_atoms(logger)
formula = None
try:
Expand Down
2 changes: 1 addition & 1 deletion simulationdataschema/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import get_sub_section_from_section_parent
from .utils import get_sibling_section
52 changes: 31 additions & 21 deletions simulationdataschema/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,50 @@
# limitations under the License.
#

from typing import Optional
from structlog.stdlib import BoundLogger

from nomad.datamodel.data import ArchiveSection


def get_sub_section_from_section_parent(
section: ArchiveSection, sub_section: str, logger
def get_sibling_section(
section: ArchiveSection, sibling_section_name: str, logger: BoundLogger
) -> ArchiveSection:
"""
Gets the sub_section of a section by performing a seesaw move, going to the parent of
section, and then going down to sub_section. Example, if section is `Symmetry`, and we
want to resolve `AtomicCell` (sub_section), this methods goes up to `ModelSystem` from `Symmetry`,
and then goes down to `AtomicCell`.
Gets the sibling section of a section by performing a seesaw move by going to the parent
of the section and then going down to the sibling section. This is used, e.g., to get
the `AtomicCell` section from the `Symmetry` section and by passing through the `ModelSystem`.
Example of the sections structure:
parent_section
|__ section
|__ sibling_section
If the sub_section is a list, it returns the first element of the list. If the sub_section is
a single section, it returns the section.
If the sibling_section is a list, it returns the first element of the list. If the sibling_section is
a single section, it returns the sibling_section itself.
Args:
section (ArchiveSection): The section to check for its parent and retrieve the sub_section.
sub_section (str): The name of the sub_section to retrieve from the parent.
section (ArchiveSection): The section to check for its parent and retrieve the sibling_section.
sibling_section (str): The name of the sibling_section to retrieve from the parent.
logger (BoundLogger): The logger to log messages.
Returns:
sub_section_sec (ArchiveSection): The sub_section to be returned.
sibling_section (ArchiveSection): The sibling_section to be returned.
"""
if section.m_parent is None:
logger.error("Could not find the parent of the section.")
logger.warning("Could not find the parent of the section.")
return
if not section.m_parent.m_xpath(sub_section):
logger.error("Could not find the section.m_parent.sub_section.")
if not section.m_parent.m_xpath(sibling_section_name):
logger.warning("Could not find the section.m_parent.sub_section.")
return
sub_section_sec = getattr(section.m_parent, sub_section)
if isinstance(sub_section_sec, list):
if len(sub_section_sec) == 0:
logger.error("The sub_section is empty.")
sibling_section = getattr(section.m_parent, sibling_section_name)
if isinstance(sibling_section, list):
if len(sibling_section) == 0:
logger.warning("The sub_section is empty.")
return
return sub_section_sec[0]
elif isinstance(sub_section_sec, ArchiveSection):
return sub_section_sec
return sibling_section[0] # ? extend for any section targeted as input
elif isinstance(sibling_section, ArchiveSection):
return sibling_section
return

0 comments on commit 0bbc90b

Please sign in to comment.