Skip to content

Commit

Permalink
Added functional programming for KLinePath.resolve_points
Browse files Browse the repository at this point in the history
Added testing for check_high_symmetry_path
  • Loading branch information
JosePizarro3 committed May 22, 2024
1 parent a650d88 commit a9bdac2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 58 deletions.
102 changes: 45 additions & 57 deletions src/nomad_simulations/numerical_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np
import pint
import itertools
from itertools import accumulate, tee, chain
from structlog.stdlib import BoundLogger
from typing import Optional, List, Tuple, Union, Dict
from ase.dft.kpoints import monkhorst_pack, get_monkhorst_pack_size_and_offset
Expand Down Expand Up @@ -536,6 +536,9 @@ def _check_high_symmetry_path(self, logger: BoundLogger) -> bool:
if (
self.high_symmetry_path_names is None
or self.high_symmetry_path_values is None
) or (
len(self.high_symmetry_path_names) == 0
or len(self.high_symmetry_path_values) == 0
):
logger.warning(
'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.'
Expand Down Expand Up @@ -575,12 +578,12 @@ def get_high_symmetry_path_norms(
return None
rlv = reciprocal_lattice_vectors.magnitude

def calc_norms(value_rlv, prev_value_rlv):
def calc_norms(
value_rlv: np.ndarray, prev_value_rlv: np.ndarray
) -> pint.Quantity:
value_tot_rlv = value_rlv - prev_value_rlv
return np.linalg.norm(value_tot_rlv) * reciprocal_lattice_vectors.u

from itertools import accumulate, tee

# Compute `rlv` projections
rlv_projections = list(
map(lambda value: value @ rlv, self.high_symmetry_path_values)
Expand All @@ -589,8 +592,7 @@ def calc_norms(value_rlv, prev_value_rlv):
# Create two iterators for the projections
rlv_projections_1, rlv_projections_2 = tee(rlv_projections)

# Initialize the previous value iterators and skip the first element in the second iterator
prev_value_rlv = np.array([0, 0, 0])
# Skip the first element in the second iterator
next(rlv_projections_2, None)

# Calculate the norms using accumulate
Expand All @@ -601,29 +603,6 @@ def calc_norms(value_rlv, prev_value_rlv):
)
return list(norms)

# # initializing the norms list (the first point has a norm of 0)
# high_symmetry_path_value_norms = [0.0 * reciprocal_lattice_vectors.u]
# # initializing the first point
# prev_value_norm = 0.0 * reciprocal_lattice_vectors.u
# prev_value_rlv = np.array([0, 0, 0])
# for i, value in enumerate(self.high_symmetry_path_values):
# if i == 0:
# continue
# value_rlv = value @ rlv
# value_tot_rlv = value_rlv - prev_value_rlv
# value_norm = (
# np.linalg.norm(value_tot_rlv) * reciprocal_lattice_vectors.u
# + prev_value_norm
# )

# # store in new path norms variable
# high_symmetry_path_value_norms.append(value_norm)

# # accumulate value vector and norm
# prev_value_rlv = value_rlv
# prev_value_norm = value_norm
# return high_symmetry_path_value_norms

def resolve_points(
self,
points_norm: Union[np.ndarray, List[float]],
Expand All @@ -647,6 +626,7 @@ def resolve_points(
'The `reciprocal_lattice_vectors` are not passed as an input.'
)
return None

# Check if `points_norm` is a list and convert it to a numpy array
if isinstance(points_norm, list):
points_norm = np.array(points_norm)
Expand All @@ -658,38 +638,46 @@ def resolve_points(
)
self.n_line_points = len(points_norm)

# Calculate the total norm of the path in order to find the closest indices in the list of `points_norm`
# Calculate the norms in the path and find the closest indices in points_norm to the high symmetry path norms
high_symmetry_path_value_norms = self.get_high_symmetry_path_norms(
reciprocal_lattice_vectors, logger
)
closest_indices = []
for i, norm in enumerate(high_symmetry_path_value_norms):
closest_idx = (np.abs(points_norm - norm.magnitude)).argmin()
closest_indices.append(closest_idx)

# Append the data in the new `points` in units of the `reciprocal_lattice_vectors`
points = []
for i, value in enumerate(self.high_symmetry_path_values):
if i == 0:
prev_value = value
prev_index = closest_indices[i]
continue
elif i == len(self.high_symmetry_path_values) - 1:
points.append(
np.linspace(
prev_value, value, num=closest_indices[i] - prev_index + 1
)
)
else:
# pop the last element as it appears repeated in the next segment
points.append(
np.linspace(
prev_value, value, num=closest_indices[i] - prev_index + 1
)[:-1]
closest_indices = list(
map(
lambda norm: (np.abs(points_norm - norm.magnitude)).argmin(),
high_symmetry_path_value_norms,
)
)

def linspace_segments(
prev_value: np.ndarray, value: np.ndarray, num: int
) -> np.ndarray:
return np.linspace(prev_value, value, num=num + 1)[:-1]

# Generate point segments using `map` and `linspace_segments`
points_segments = list(
map(
lambda i, value: linspace_segments(
self.high_symmetry_path_values[i - 1],
value,
closest_indices[i] - closest_indices[i - 1],
)
prev_value = value
prev_index = closest_indices[i]
new_points = list(itertools.chain(*points))
if i > 0
else np.array([]),
range(len(self.high_symmetry_path_values)),
self.high_symmetry_path_values,
)
)
# and handle the last segment to include all points
points_segments[-1] = np.linspace(
self.high_symmetry_path_values[-2],
self.high_symmetry_path_values[-1],
num=closest_indices[-1] - closest_indices[-2] + 1,
)

# Flatten the list of segments into a single list of points
new_points = list(chain.from_iterable(points_segments))

# And store this information in the `points` quantity
if self.points is not None:
logger.info('Overwriting `KLinePath.points` with the resolved points.')
Expand Down
27 changes: 26 additions & 1 deletion tests/test_numerical_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from nomad_simulations.numerical_settings import KMesh, KLinePath

from . import logger
from .conftest import generate_k_space_simulation
from .conftest import generate_k_line_path, generate_k_space_simulation


class TestKSpace:
Expand Down Expand Up @@ -259,6 +259,31 @@ class TestKLinePath:
Test the `KLinePath` class defined in `numerical_settings.py`.
"""

@pytest.mark.parametrize(
'high_symmetry_path_names, high_symmetry_path_values, result',
[
(None, None, False),
([], [], False),
(['Gamma', 'X', 'Y'], None, False),
([], [[0, 0, 0], [0.5, 0, 0], [0, 0.5, 0]], False),
(['Gamma', 'X', 'Y'], [[0, 0, 0], [0.5, 0, 0], [0, 0.5, 0]], True),
],
)
def test_check_high_symmetry_path(
self,
high_symmetry_path_names: List[str],
high_symmetry_path_values: List[List[float]],
result: bool,
):
"""
Test the `_check_high_symmetry_path` private method.
"""
k_line_path = generate_k_line_path(
high_symmetry_path_names=high_symmetry_path_names,
high_symmetry_path_values=high_symmetry_path_values,
)
assert k_line_path._check_high_symmetry_path(logger) == result

def test_get_high_symmetry_path_norm(self, k_line_path: KLinePath):
"""
Test the `get_high_symmetry_path_norm` method.
Expand Down

0 comments on commit a9bdac2

Please sign in to comment.