Skip to content

Commit

Permalink
working on CG schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Bernadette-Mohr committed Oct 11, 2024
1 parent 4e2b559 commit beb7a8b
Show file tree
Hide file tree
Showing 2 changed files with 499 additions and 740 deletions.
162 changes: 162 additions & 0 deletions src/nomad_simulations/schema_packages/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from structlog.stdlib import BoundLogger

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.particles_state import ParticlesState
from nomad_simulations.schema_packages.utils import (
get_sibling_section,
is_not_representative,
Expand Down Expand Up @@ -492,6 +493,167 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
self.name = self.m_def.name if self.name is None else self.name


class ParticleCell(Cell):
"""
A base section used to specify the atomic cell information of a system.
"""

particles_state = SubSection(sub_section=ParticlesState.m_def, repeats=True)

n_particles = Quantity(
type=np.int32,
description="""
Number of atoms in the atomic cell.
""",
)

equivalent_particles = Quantity(
type=np.int32,
shape=['n_atoms'],
description="""
List of equivalent atoms as defined in `atoms`. If no equivalent atoms are found,
then the list is simply the index of each element, e.g.:
- [0, 1, 2, 3] all four atoms are non-equivalent.
- [0, 0, 0, 3] three equivalent atoms and one non-equivalent.
""",
)

def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs):
super().__init__(m_def, m_context, **kwargs)
# Set the name of the section
self.name = self.m_def.name

def is_equal_cell(self, other) -> bool:
"""
Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and
the `AtomsState[*].chemical_symbol`.
Args:
other: The other atomic cell to compare with.
Returns:
bool: True if the atomic cells are equal, False otherwise.
"""
if not isinstance(other, ParticleCell):
return False

# Compare positions using the parent sections's `__eq__` method
if not super().is_equal_cell(other=other):
return False

# Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2`
check_positions = self._check_positions(
positions_1=self.positions, positions_2=other.positions
)
try:
for particle in check_positions:
type_1 = self.particles_state[particle[0]].particle_type
type_2 = other.particles_state[particle[1]].particle_type
if type_1 != type_2:
return False
except Exception:
return False
return True

# def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
# """
# Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.

# Args:
# logger (BoundLogger): The logger to log messages.

# Returns:
# list: The list of chemical symbols of the atoms in the atomic cell.
# """
# if not self.atoms_state:
# return []

# chemical_symbols = []
# for atom_state in self.atoms_state:
# if not atom_state.chemical_symbol:
# logger.warning('Could not find `AtomsState[*].chemical_symbol`.')
# return []
# chemical_symbols.append(atom_state.chemical_symbol)
# return chemical_symbols

# def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
# """
# Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
# section (labels, periodic_boundary_conditions, positions, and lattice_vectors).

# Args:
# logger (BoundLogger): The logger to log messages.

# Returns:
# (Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`.
# """
# # Initialize ase.Atoms object with labels
# atoms_labels = self.get_chemical_symbols(logger=logger)
# ase_atoms = ase.Atoms(symbols=atoms_labels)

# # PBC
# if self.periodic_boundary_conditions is None:
# logger.info(
# 'Could not find `AtomicCell.periodic_boundary_conditions`. They will be set to [False, False, False].'
# )
# self.periodic_boundary_conditions = [False, False, False]
# ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions)

# # Lattice vectors
# if self.lattice_vectors is not None:
# ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude)
# else:
# logger.info('Could not find `AtomicCell.lattice_vectors`.')

# # Positions
# if self.positions is not None:
# if len(self.positions) != len(self.atoms_state):
# logger.error(
# 'Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.atoms_state`.'
# )
# return None
# ase_atoms.set_positions(
# newpositions=self.positions.to('angstrom').magnitude
# )
# else:
# logger.warning('Could not find `AtomicCell.positions`.')
# return None

# return ase_atoms

# def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None:
# """
# Parses the information from an ASE Atoms object to the `AtomicCell` section.

# Args:
# ase_atoms (ase.Atoms): The ASE Atoms object to parse.
# logger (BoundLogger): The logger to log messages.
# """
# # `AtomsState[*].chemical_symbol`
# for symbol in ase_atoms.get_chemical_symbols():
# atom_state = AtomsState(chemical_symbol=symbol)
# self.atoms_state.append(atom_state)

# # `periodic_boundary_conditions`
# self.periodic_boundary_conditions = ase_atoms.get_pbc()

# # `lattice_vectors`
# cell = ase_atoms.get_cell()
# self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom')

# # `positions`
# positions = ase_atoms.get_positions()
# if (
# not positions.tolist()
# ): # ASE assigns a shape=(0, 3) array if no positions are found
# return None
# self.positions = positions * ureg('angstrom')

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)

# Set the name of the section
self.name = self.m_def.name if self.name is None else self.name


class Symmetry(ArchiveSection):
"""
A base section used to specify the symmetry of the `AtomicCell`.
Expand Down
Loading

0 comments on commit beb7a8b

Please sign in to comment.