Skip to content

Commit

Permalink
Polished after meeting with Joe and Nathan
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Jan 30, 2024
1 parent d0f3260 commit f49b246
Showing 1 changed file with 52 additions and 80 deletions.
132 changes: 52 additions & 80 deletions simulationdataschema/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,9 @@ class AtomicCell(GeometricSpace):
""",
)

n_atoms = Quantity(
type=np.int32,
description="""
The total number of atoms in the system.
""",
)

labels = Quantity(
type=str,
shape=["n_atoms"],
shape=["*"],
description="""
List containing the labels of the atomic species in the system at the different positions
of the structure. It refers to a chemical element as defined in the periodic table,
Expand All @@ -116,15 +109,15 @@ class AtomicCell(GeometricSpace):

atomic_numbers = Quantity(
type=np.int32,
shape=["n_atoms"],
shape=["*"],
description="""
List of atomic numbers Z. This quantity is equivalent to `labels`.
""",
)

positions = Quantity(
type=np.float64,
shape=["n_atoms", 3],
shape=["*", 3],
unit="meter",
description="""
Positions of all the atoms in Cartesian coordinates.
Expand Down Expand Up @@ -162,7 +155,7 @@ class AtomicCell(GeometricSpace):

velocities = Quantity(
type=np.float64,
shape=["n_atoms", 3],
shape=["*", 3],
unit="meter / second",
description="""
Velocities of the atoms. It is the change in cartesian coordinates of the atom position
Expand All @@ -183,7 +176,7 @@ class AtomicCell(GeometricSpace):

equivalent_atoms = Quantity(
type=np.int32,
shape=["n_atoms"],
shape=["*"],
description="""
List of equivalent atoms as defined in `labels`. If no equivalent atoms are found,
then the list is simply the index of each element, e.g.:
Expand All @@ -194,7 +187,7 @@ class AtomicCell(GeometricSpace):

wyckoff_letters = Quantity(
type=str,
shape=["n_atoms"],
shape=["*"],
# TODO improve description
description="""
Wyckoff letters associated with each atom position.
Expand Down Expand Up @@ -272,12 +265,9 @@ def normalize(self, archive, logger):
)
return
self.labels = atomic_labels
self.n_atoms = len(atomic_labels)

# We will use ASE Atoms functionalities to extract information about the AtomicCell
ase_atoms = self.to_ase_atoms(logger)
# Store temporarily the ase.Atoms object to use in other ModelSystem section normalizers
self.m_cache["ase_atoms"] = ase_atoms

# Atomic numbers
if atomic_labels is not None and atomic_numbers is None:
Expand Down Expand Up @@ -415,46 +405,6 @@ def resolve_bulk_symmetry(self, original_atomic_cell, logger):
symmetry_analyzer = SymmetryAnalyzer(
ase_atoms, symmetry_tol=config.normalize.symmetry_tolerance
)

symmetry["bravais_lattice"] = symmetry_analyzer.get_bravais_lattice()
symmetry["hall_symbol"] = symmetry_analyzer.get_hall_symbol()
symmetry["point_group_symbol"] = symmetry_analyzer.get_point_group()
symmetry["space_group_number"] = symmetry_analyzer.get_space_group_number()
symmetry[
"space_group_symbol"
] = symmetry_analyzer.get_space_group_international_short()
symmetry["origin_shift"] = symmetry_analyzer._get_spglib_origin_shift()
symmetry[
"transformation_matrix"
] = symmetry_analyzer._get_spglib_transformation_matrix()

# Originally parsed cell
original_wyckoff = symmetry_analyzer.get_wyckoff_letters_original()
original_equivalent_atoms = (
symmetry_analyzer.get_equivalent_atoms_original()
)

# Primitive cell
primitive_wyckoff = symmetry_analyzer.get_wyckoff_letters_primitive()
primitive_equivalent_atoms = (
symmetry_analyzer.get_equivalent_atoms_primitive()
)
primitive_sys = symmetry_analyzer.get_primitive_system()
primitive_pos = primitive_sys.get_scaled_positions()
primitive_cell = primitive_sys.get_cell()
primitive_num = primitive_sys.get_atomic_numbers()
primitive_labels = primitive_sys.get_chemical_symbols()

# Standarized (or conventional) cell
standard_wyckoff = symmetry_analyzer.get_wyckoff_letters_conventional()
standard_equivalent_atoms = (
symmetry_analyzer.get_equivalent_atoms_conventional()
)
standard_sys = symmetry_analyzer.get_conventional_system()
standard_pos = standard_sys.get_scaled_positions()
standard_cell = standard_sys.get_cell()
standard_num = standard_sys.get_atomic_numbers()
standard_labels = standard_sys.get_chemical_symbols()
except ValueError as e:
logger.debug(
"Symmetry analysis with MatID is not available.", details=str(e)
Expand All @@ -464,13 +414,33 @@ def resolve_bulk_symmetry(self, original_atomic_cell, logger):
logger.warning("Symmetry analysis with MatID failed.", exc_info=e)
return

# We store symmetry_analyzer info in a dictionary
symmetry["bravais_lattice"] = symmetry_analyzer.get_bravais_lattice()
symmetry["hall_symbol"] = symmetry_analyzer.get_hall_symbol()
symmetry["point_group_symbol"] = symmetry_analyzer.get_point_group()
symmetry["space_group_number"] = symmetry_analyzer.get_space_group_number()
symmetry[
"space_group_symbol"
] = symmetry_analyzer.get_space_group_international_short()
symmetry["origin_shift"] = symmetry_analyzer._get_spglib_origin_shift()
symmetry[
"transformation_matrix"
] = symmetry_analyzer._get_spglib_transformation_matrix()

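# Populating the originally parsed AtomicCell wyckoff_letters and equivalent_atoms information
# NOTE(review): the assignment below writes to `equivalent_particles`, but the quantity declared on AtomicCell in this file is `equivalent_atoms` - confirm which name is intended.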
# sec_original_atoms = self.atomic_cell[0]
# sec_original_atoms = original_atomic_cell
original_wyckoff = symmetry_analyzer.get_wyckoff_letters_original()
original_equivalent_atoms = symmetry_analyzer.get_equivalent_atoms_original()
original_atomic_cell.wyckoff_letters = original_wyckoff
original_atomic_cell.equivalent_particles = original_equivalent_atoms

# Populating the primitive AtomicCell information
primitive_wyckoff = symmetry_analyzer.get_wyckoff_letters_primitive()
primitive_equivalent_atoms = symmetry_analyzer.get_equivalent_atoms_primitive()
primitive_sys = symmetry_analyzer.get_primitive_system()
primitive_pos = primitive_sys.get_scaled_positions()
primitive_cell = primitive_sys.get_cell()
primitive_num = primitive_sys.get_atomic_numbers()
primitive_labels = primitive_sys.get_chemical_symbols()
primitive_atomic_cell = AtomicCell()
primitive_atomic_cell.name = "primitive"
primitive_atomic_cell.lattice_vectors = primitive_cell * ureg.angstrom
Expand All @@ -483,6 +453,15 @@ def resolve_bulk_symmetry(self, original_atomic_cell, logger):
primitive_atomic_cell.get_geometric_space_for_atomic_cell(logger)

# Populating the standardized (conventional) Atoms information
standard_wyckoff = symmetry_analyzer.get_wyckoff_letters_conventional()
standard_equivalent_atoms = (
symmetry_analyzer.get_equivalent_atoms_conventional()
)
standard_sys = symmetry_analyzer.get_conventional_system()
standard_pos = standard_sys.get_scaled_positions()
standard_cell = standard_sys.get_cell()
standard_num = standard_sys.get_atomic_numbers()
standard_labels = standard_sys.get_chemical_symbols()
standard_atomic_cell = AtomicCell()
standard_atomic_cell.name = "standard"
standard_atomic_cell.lattice_vectors = standard_cell * ureg.angstrom
Expand Down Expand Up @@ -705,25 +684,21 @@ class ModelSystem(System):
description="""
Type of the system (atom, bulk, surface, etc.) which is determined by the normalizer.
""",
a_eln=ELNAnnotation(component="EnumEditQuantity"),
)

dimensionality = Quantity(
type=MEnum("0D", "1D", "2D", "3D", "unavailable"),
type=np.int32,
description="""
Dimensionality of the system. For atomistic systems this is automatically evaluated
by using the topology-scaling algorithm:
Dimensionality of the system: 0, 1, 2, or 3 dimensions. For atomistic systems this
is automatically evaluated by using the topology-scaling algorithm:
https://doi.org/10.1103/PhysRevLett.118.106101.
| Value | Description |
| --------- | ----------------------- |
| `'0D'` | Points in the space |
| `'1D'` | Periodic in one dimension |
| `'2D'` | Periodic in two dimensions |
| `'3D'` | Periodic in three dimensions |
""",
a_eln=ELNAnnotation(component="NumberEditQuantity"),
)

# TODO improve on the definition and usage
is_representative = Quantity(
type=bool,
default=False,
Expand Down Expand Up @@ -758,14 +733,15 @@ class ModelSystem(System):
""",
)

# TODO what about `branch_index`?
# TODO what about `branch_index` or `branch_depth`?
tree_index = Quantity(
type=np.int32,
description="""
Index refering to the depth of a branch in the system tree.
""",
)

# TODO add method to resolve labels and positions from the parent AtomicCell
atom_indices = Quantity(
type=np.int32,
shape=["*"],
Expand Down Expand Up @@ -822,22 +798,22 @@ def resolve_system_type_and_dimensionality(self, ase_atoms):
classification = type(cls)
if classification == Class3D:
system_type = "bulk"
dimensionality = "3D"
dimensionality = 3
elif classification == Atom:
system_type = "atom"
dimensionality = "3D"
dimensionality = 0
elif classification == Class0D:
system_type = "molecule / cluster"
dimensionality = "0D"
dimensionality = 0
elif classification == Class1D:
system_type = "1D"
dimensionality = "1D"
dimensionality = 1
elif classification == Surface:
system_type = "surface"
dimensionality = "2D"
dimensionality = 2
elif classification == Material2D or classification == Class2D:
system_type = "2D"
dimensionality = "2D"
dimensionality = 2
else:
self.logger.info(
"ModelSystem.type and dimensionality analysis not run due to large system size."
Expand All @@ -850,7 +826,6 @@ def normalize(self, archive, logger):
self.logger = logger

# We don't need to normalize if the system is not representative
# self.is_representative = True
if not self.is_representative:
return

Expand All @@ -862,7 +837,7 @@ def normalize(self, archive, logger):
)
return
self.atomic_cell[0].name = "original"
ase_atoms = self.atomic_cell[0].m_cache.get("ase_atoms")
ase_atoms = self.atomic_cell[0].to_ase_atoms(logger)
if not ase_atoms:
return

Expand All @@ -871,9 +846,6 @@ def normalize(self, archive, logger):
original_atom_positions = self.atomic_cell[0].positions
if original_atom_positions is not None:
self.type = "unavailable" if not self.type else self.type
self.dimensionality = (
"unavailable" if not self.dimensionality else self.dimensionality
)
(
self.type,
self.dimensionality,
Expand Down

0 comments on commit f49b246

Please sign in to comment.