Skip to content

Commit

Permalink
working on CG schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Bernadette-Mohr committed Oct 14, 2024
1 parent beb7a8b commit 9508fd9
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 1,401 deletions.
188 changes: 94 additions & 94 deletions src/nomad_simulations/schema_packages/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from structlog.stdlib import BoundLogger

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.particles_state import ParticlesState
from nomad_simulations.schema_packages.particles_state import Particles, ParticlesState
from nomad_simulations.schema_packages.utils import (
get_sibling_section,
is_not_representative,
Expand Down Expand Up @@ -553,99 +553,99 @@ def is_equal_cell(self, other) -> bool:
return False
return True

# def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
# """
# Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.

# Args:
# logger (BoundLogger): The logger to log messages.

# Returns:
# list: The list of chemical symbols of the atoms in the atomic cell.
# """
# if not self.atoms_state:
# return []

# chemical_symbols = []
# for atom_state in self.atoms_state:
# if not atom_state.chemical_symbol:
# logger.warning('Could not find `AtomsState[*].chemical_symbol`.')
# return []
# chemical_symbols.append(atom_state.chemical_symbol)
# return chemical_symbols

# def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
# """
# Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
# section (labels, periodic_boundary_conditions, positions, and lattice_vectors).

# Args:
# logger (BoundLogger): The logger to log messages.

# Returns:
# (Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`.
# """
# # Initialize ase.Atoms object with labels
# atoms_labels = self.get_chemical_symbols(logger=logger)
# ase_atoms = ase.Atoms(symbols=atoms_labels)

# # PBC
# if self.periodic_boundary_conditions is None:
# logger.info(
# 'Could not find `AtomicCell.periodic_boundary_conditions`. They will be set to [False, False, False].'
# )
# self.periodic_boundary_conditions = [False, False, False]
# ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions)

# # Lattice vectors
# if self.lattice_vectors is not None:
# ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude)
# else:
# logger.info('Could not find `AtomicCell.lattice_vectors`.')

# # Positions
# if self.positions is not None:
# if len(self.positions) != len(self.atoms_state):
# logger.error(
# 'Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.atoms_state`.'
# )
# return None
# ase_atoms.set_positions(
# newpositions=self.positions.to('angstrom').magnitude
# )
# else:
# logger.warning('Could not find `AtomicCell.positions`.')
# return None

# return ase_atoms

# def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None:
# """
# Parses the information from an ASE Atoms object to the `AtomicCell` section.

# Args:
# ase_atoms (ase.Atoms): The ASE Atoms object to parse.
# logger (BoundLogger): The logger to log messages.
# """
# # `AtomsState[*].chemical_symbol`
# for symbol in ase_atoms.get_chemical_symbols():
# atom_state = AtomsState(chemical_symbol=symbol)
# self.atoms_state.append(atom_state)

# # `periodic_boundary_conditions`
# self.periodic_boundary_conditions = ase_atoms.get_pbc()

# # `lattice_vectors`
# cell = ase_atoms.get_cell()
# self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom')

# # `positions`
# positions = ase_atoms.get_positions()
# if (
# not positions.tolist()
# ): # ASE assigns a shape=(0, 3) array if no positions are found
# return None
# self.positions = positions * ureg('angstrom')
def get_particle_types(self, logger: 'BoundLogger') -> list[str]:
"""
Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
list: The list of chemical symbols of the atoms in the atomic cell.
"""
if not self.particles_state:
return []

particle_labels = []
for particle_state in self.particles_state:
if not particle_state.particle_type:
logger.warning('Could not find `ParticlesState[*].particle_type`.')
return []
particle_labels.append(particle_state.particle_type)
return particle_labels

def to_particles(self, logger: 'BoundLogger') -> Optional[Particles]:
"""
Generates a Particles object with the most basic information from the parsed `ParticleCell`
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[Particles]): The Partilces object with the basic information from the `ParticleCell`.
"""
# Initialize Partilces object with labels
particle_labels = self.get_particle_types(logger=logger)
particles = Particles(symbols=particle_labels)

# PBC
if self.periodic_boundary_conditions is None:
logger.info(
'Could not find `ParticleCell.periodic_boundary_conditions`. They will be set to [False, False, False].'
)
self.periodic_boundary_conditions = [False, False, False]
particles.set_pbc(pbc=self.periodic_boundary_conditions)

# Lattice vectors
if self.lattice_vectors is not None:
particles.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude)
else:
logger.info('Could not find `ParticleCell.lattice_vectors`.')

# Positions
if self.positions is not None:
if len(self.positions) != len(self.particles_state):
logger.error(
'Length of `ParticleCell.positions` does not coincide with the length of the `ParticleCell.particles_state`.'
)
return None
particles.set_positions(
newpositions=self.positions.to('angstrom').magnitude
)
else:
logger.warning('Could not find `ParticleCell.positions`.')
return None

return particles

def from_particles(self, particles: Particles, logger: 'BoundLogger') -> None:
"""
Parses the information from a Particles object to the `ParticlesCell` section.
Args:
particles (Particles): The Particles object to parse.
logger (BoundLogger): The logger to log messages.
"""
# `ParticlesState[*].particles_type`
for label in particles.get_particle_types():
particle_state = ParticlesState(particle_type=label)
self.particles_state.append(particle_state)

# `periodic_boundary_conditions`
self.periodic_boundary_conditions = particles.get_pbc()

# `lattice_vectors`
cell = particles.get_cell()
self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom')

# `positions`
positions = particles.get_positions()
if (
not positions.tolist()
): # ASE assigns a shape=(0, 3) array if no positions are found
return None
self.positions = positions * ureg('angstrom')

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
Expand Down
Loading

0 comments on commit 9508fd9

Please sign in to comment.