Skip to content

Commit

Permalink
fixing kpoint path
Browse files Browse the repository at this point in the history
  • Loading branch information
Miki Bonacci committed Feb 19, 2024
1 parent e54dd58 commit 7abce1a
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions aiida_yambo_wannier90/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aiida_wannier90_workflows.utils.kpoints import (
get_explicit_kpoints,
get_mesh_from_kpoints,
get_path_from_kpoints
)
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_kpoints
from aiida_wannier90_workflows.workflows import (
Expand Down Expand Up @@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
"""Initialize context variables."""

self.ctx.current_structure = self.inputs.structure

if "bands_kpoints" in self.inputs:
self.ctx.bands_kpoints = self.inputs.bands_kpoints


# Converged mesh from YamboConvergence
self.ctx.kpoints_gw_conv = None
Expand Down Expand Up @@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements

def should_run_seekpath(self):
"""Run seekpath if the `inputs.bands_kpoints` is not provided."""
return "bands_kpoints" not in self.inputs
if "bands_kpoints" in self.inputs:
self.ctx.current_kpoint_path = get_path_from_kpoints(
self.inputs["bands_kpoints"]
)
return False
else:
return True

def run_seekpath(self):
"""Run the structure through SeeKpath to get the primitive and normalized structure."""
Expand All @@ -692,7 +697,11 @@ def run_seekpath(self):

self.ctx.current_structure = result["primitive_structure"]

self.ctx.current_bands_kpoints = result["explicit_kpoints"]
# Add `kpoint_path` for Wannier bands
self.ctx.current_kpoint_path = get_path_from_kpoints(
result["explicit_kpoints"]
)


structure_formula = self.inputs.structure.get_formula()
primitive_structure_formula = result["primitive_structure"].get_formula()
Expand Down Expand Up @@ -1060,7 +1069,8 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict:
params["bands_plot"] = False
inputs.wannier90.parameters = orm.Dict(params)

#inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.wannier90.kpoint_path = self.ctx.current_kpoint_path

# Use commensurate kmesh
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
Expand Down Expand Up @@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict:
)

inputs.structure = self.ctx.current_structure
inputs.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.wannier90.wannier90.kpoint_path = self.ctx.current_kpoint_path

# Use commensurate kmesh
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
Expand Down Expand Up @@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict:
)

inputs.wannier90.structure = self.ctx.current_structure
inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.kpoint_path = self.ctx.current_kpoint_path

if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
set_kpoints(
Expand Down

0 comments on commit 7abce1a

Please sign in to comment.