Skip to content

Commit

Permalink
Fix logic for momentum transfer and resolve_type
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Apr 15, 2024
1 parent 8711451 commit ae7c744
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
36 changes: 20 additions & 16 deletions src/nomad_simulations/properties/band_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,29 @@ def resolve_type(self, logger: BoundLogger) -> Optional[str]:
Returns:
(Optional[str]): The resolved `type` of the electronic band gap.
"""
if (
self.momentum_transfer is None or len(self.momentum_transfer) < 2
) and self.type == 'indirect':
mtr = self.momentum_transfer if self.momentum_transfer is not None else []

# Check if the `momentum_transfer` is [], and return the type and a warning in the log for `indirect` band gaps
if len(mtr) == 0:
if self.type == 'indirect':
logger.warning(
'The `momentum_transfer` is not stored for an `indirect` band gap.'
)
return self.type

# Check if the `momentum_transfer` has at least two elements, and return None if it does not
if len(mtr) == 1:
logger.warning(
"The `momentum_transfer` is not properly defined for an `type='indirect'` electronic band gap."
'The `momentum_transfer` should have at least two elements so that the difference can be calculated and the type of electronic band gap can be resolved.'
)
return None
if self.momentum_transfer is not None and len(self.momentum_transfer) > 0:
if len(self.momentum_transfer) == 1:
logger.warning(
'The `momentum_transfer` should have at least two elements so that the difference can be calculated and the type of electronic band gap can be resolved.'
)
return None
momentum_difference = np.diff(self.momentum_transfer, axis=0)
if (np.isclose(momentum_difference, np.zeros(3))).all():
return 'direct'
else:
return 'indirect'
return self.type

# Resolve `type` from the difference between the initial and final momentum transfer
momentum_difference = np.diff(mtr, axis=0)
if (np.isclose(momentum_difference, np.zeros(3))).all():
return 'direct'
else:
return 'indirect'

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_band_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,18 @@ def test_check_negative_values(
[
(None, None, None),
(None, 'direct', 'direct'),
(None, 'indirect', None),
(None, 'indirect', 'indirect'),
([], None, None),
([], 'direct', 'direct'),
([], 'indirect', None),
([], 'indirect', 'indirect'),
([[0, 0, 0]], None, None),
([[0, 0, 0]], 'direct', None),
([[0, 0, 0]], 'indirect', None),
([[0, 0, 0], [0, 0, 0]], None, 'direct'),
([[0, 0, 0], [0, 0, 0]], 'direct', 'direct'),
([[0, 0, 0], [0, 0, 0]], 'indirect', 'direct'),
([[0, 0, 0], [0.5, 0.5, 0.5]], None, 'indirect'),
([[0, 0, 0], [0.5, 0.5, 0.5]], 'direct', 'indirect'),
([[0, 0, 0], [0.5, 0.5, 0.5]], 'indirect', 'indirect'),
],
)
Expand Down

0 comments on commit ae7c744

Please sign in to comment.