Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added generate_projections for List and Dict #116

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions aiida_wannier90/io/_write_win.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from aiida.common import InputValidationError

from ..utils import conv_to_fortran_withlists
from ..orbitals import generate_projections
from ._group_list import list_to_grouped_string

__all__ = ('write_win', )
Expand Down Expand Up @@ -46,8 +47,9 @@ def write_win( # pylint: disable=too-many-arguments
:type kpoint_path: aiida.orm.nodes.data.dict.Dict

:param projections: Orbitals used for the projections. Can be specified either as AiiDA class :py:class:`OrbitalData <aiida.orm.OrbitalData>`,
or as a list of strings specifying the projections in Wannier90's format.
:type projections: aiida.orm.nodes.data.orbital.OrbitalData, aiida.orm.nodes.data.list.List[str]
or as a list of strings specifying the projections in Wannier90's format,
or as a list of dict in the format of the argument of `generate_projections`.
:type projections: aiida.orm.nodes.data.orbital.OrbitalData, aiida.orm.nodes.data.list.List[str], aiida.orm.nodes.data.list.List[dict]

:param random_projections: If class :py:class:`OrbitalData <aiida.orm.OrbitalData>` is used for projections, enables random projections completion
:type random_projections: aiida.orm.nodes.data.bool.Bool
Expand All @@ -65,7 +67,7 @@ def write_win( # pylint: disable=too-many-arguments
)


def _create_win_string( # pylint: disable=too-many-branches,missing-function-docstring
def _create_win_string( # pylint: disable=too-many-branches,missing-function-docstring # noqa: MC0001
parameters,
kpoints,
structure=None,
Expand All @@ -74,7 +76,7 @@ def _create_win_string( # pylint: disable=too-many-branches,missing-function-doc
random_projections=False,
):
from aiida.plugins import DataFactory
from aiida.orm import List
from aiida.orm import List, Dict

# prepare the main input text
input_file_lines = []
Expand Down Expand Up @@ -106,7 +108,31 @@ def _create_win_string( # pylint: disable=too-many-branches,missing-function-doc
'random_projections cannot be True if with List-type projections.'
'Instead, use "random" string as first element of the List.'
)
block_inputs['projections'] = projections.get_list()
lst = projections.get_list()
if all(isinstance(x, str) for x in lst):
block_inputs['projections'] = lst
elif all(isinstance(x, dict) for x in lst):
orbital_data = generate_projections(lst, structure=structure)
block_inputs['projections'] = _format_all_projections(
orbital_data, random_projections=True
)
else:
raise InputValidationError(
'Projections List contains wrong elements.'
'They need to be either all strings or all dicts.'
)

elif isinstance(projections, Dict):
if random_projections:
raise InputValidationError(
'random_projections cannot be True if with Dict-type projections.'
)
orbital_data = generate_projections(
projections.get_dict(), structure=structure
)
block_inputs['projections'] = _format_all_projections(
orbital_data, random_projections=True
)
else:
block_inputs['projections'] = _format_all_projections(
projections, random_projections=True
Expand Down
53 changes: 53 additions & 0 deletions tests/io/test_win_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest
from aiida.common.exceptions import InputValidationError
from aiida.orm import Dict, List


def test_create_win_string(generate_win_params_gaas, file_regression):
Expand All @@ -22,6 +23,58 @@ def test_create_win_string(generate_win_params_gaas, file_regression):
)


def test_create_win_string_projections_list_of_str(
generate_win_params_gaas, file_regression
):
"""Test _write_win for parameter projections using a List of str."""
from aiida_wannier90.io._write_win import _create_win_string

gaas_params = generate_win_params_gaas()
gaas_params['projections'] = List(list=["Ga: p", "As: p"])
file_regression.check(
_create_win_string(**gaas_params), encoding='utf-8', extension='.win'
)


def test_create_win_string_projections_list_of_dict(
generate_win_params_gaas, file_regression
):
"""Test _write_win for parameter projections using a List of dict."""
from aiida_wannier90.io._write_win import _create_win_string

gaas_params = generate_win_params_gaas()
gaas_params['projections'] = List(
list=[{
"kind_name": "Ga",
"ang_mtm_name": ["p"]
}, {
"kind_name": "As",
"ang_mtm_name": ["p"]
}]
)
file_regression.check(
_create_win_string(**gaas_params), encoding='utf-8', extension='.win'
)


def test_create_win_string_projections_dict(
generate_win_params_gaas, file_regression
):
"""Test _write_win for parameter projections using a List of str."""
from aiida_wannier90.io._write_win import _create_win_string

gaas_params = generate_win_params_gaas()
gaas_params['projections'] = Dict(
dict={
"kind_name": "Ga",
"ang_mtm_name": ["p"]
}
)
file_regression.check(
_create_win_string(**gaas_params), encoding='utf-8', extension='.win'
)


def test_exclude_bands(generate_kpoints_mesh, file_regression):
"""Test _write_win for parameter exclude_bands"""
from aiida_wannier90.io._write_win import _create_win_string
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
mp_grid = 2, 2, 2
num_iter = 12
num_wann = 4
wvfn_formatted = .true.

begin atoms_cart
ang
Ga 0.0000000000 0.0000000000 0.0000000000
As -1.4200000000 1.4200000000 1.4200000000
end atoms_cart

begin kpoint_path
G 0.0 0.0 0.0 X 0.5 0.0 0.5
X 0.5 0.0 0.5 W 0.5 0.25 0.75
W 0.5 0.25 0.75 K 0.375 0.375 0.75
K 0.375 0.375 0.75 G 0.0 0.0 0.0
G 0.0 0.0 0.0 L 0.5 0.5 0.5
L 0.5 0.5 0.5 U 0.625 0.25 0.625
U 0.625 0.25 0.625 W 0.5 0.25 0.75
W 0.5 0.25 0.75 L 0.5 0.5 0.5
L 0.5 0.5 0.5 K 0.375 0.375 0.75
U 0.625 0.25 0.625 X 0.5 0.0 0.5
end kpoint_path

begin kpoints
0.0000000000 0.0000000000 0.0000000000
0.0000000000 0.0000000000 0.5000000000
0.0000000000 0.5000000000 0.0000000000
0.0000000000 0.5000000000 0.5000000000
0.5000000000 0.0000000000 0.0000000000
0.5000000000 0.0000000000 0.5000000000
0.5000000000 0.5000000000 0.0000000000
0.5000000000 0.5000000000 0.5000000000
end kpoints

begin projections
random
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=1:::r=1:
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=2:::r=1:
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=3:::r=1:
end projections

begin unit_cell_cart
ang
-2.8400000000 0.0000000000 2.8400000000
0.0000000000 2.8400000000 2.8400000000
-2.8400000000 2.8400000000 0.0000000000
end unit_cell_cart
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
mp_grid = 2, 2, 2
num_iter = 12
num_wann = 4
wvfn_formatted = .true.

begin atoms_cart
ang
Ga 0.0000000000 0.0000000000 0.0000000000
As -1.4200000000 1.4200000000 1.4200000000
end atoms_cart

begin kpoint_path
G 0.0 0.0 0.0 X 0.5 0.0 0.5
X 0.5 0.0 0.5 W 0.5 0.25 0.75
W 0.5 0.25 0.75 K 0.375 0.375 0.75
K 0.375 0.375 0.75 G 0.0 0.0 0.0
G 0.0 0.0 0.0 L 0.5 0.5 0.5
L 0.5 0.5 0.5 U 0.625 0.25 0.625
U 0.625 0.25 0.625 W 0.5 0.25 0.75
W 0.5 0.25 0.75 L 0.5 0.5 0.5
L 0.5 0.5 0.5 K 0.375 0.375 0.75
U 0.625 0.25 0.625 X 0.5 0.0 0.5
end kpoint_path

begin kpoints
0.0000000000 0.0000000000 0.0000000000
0.0000000000 0.0000000000 0.5000000000
0.0000000000 0.5000000000 0.0000000000
0.0000000000 0.5000000000 0.5000000000
0.5000000000 0.0000000000 0.0000000000
0.5000000000 0.0000000000 0.5000000000
0.5000000000 0.5000000000 0.0000000000
0.5000000000 0.5000000000 0.5000000000
end kpoints

begin projections
random
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=1:::r=1:
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=2:::r=1:
c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=3:::r=1:
c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=1:::r=1:
c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=2:::r=1:
c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=3:::r=1:
end projections

begin unit_cell_cart
ang
-2.8400000000 0.0000000000 2.8400000000
0.0000000000 2.8400000000 2.8400000000
-2.8400000000 2.8400000000 0.0000000000
end unit_cell_cart
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
mp_grid = 2, 2, 2
num_iter = 12
num_wann = 4
wvfn_formatted = .true.

begin atoms_cart
ang
Ga 0.0000000000 0.0000000000 0.0000000000
As -1.4200000000 1.4200000000 1.4200000000
end atoms_cart

begin kpoint_path
G 0.0 0.0 0.0 X 0.5 0.0 0.5
X 0.5 0.0 0.5 W 0.5 0.25 0.75
W 0.5 0.25 0.75 K 0.375 0.375 0.75
K 0.375 0.375 0.75 G 0.0 0.0 0.0
G 0.0 0.0 0.0 L 0.5 0.5 0.5
L 0.5 0.5 0.5 U 0.625 0.25 0.625
U 0.625 0.25 0.625 W 0.5 0.25 0.75
W 0.5 0.25 0.75 L 0.5 0.5 0.5
L 0.5 0.5 0.5 K 0.375 0.375 0.75
U 0.625 0.25 0.625 X 0.5 0.0 0.5
end kpoint_path

begin kpoints
0.0000000000 0.0000000000 0.0000000000
0.0000000000 0.0000000000 0.5000000000
0.0000000000 0.5000000000 0.0000000000
0.0000000000 0.5000000000 0.5000000000
0.5000000000 0.0000000000 0.0000000000
0.5000000000 0.0000000000 0.5000000000
0.5000000000 0.5000000000 0.0000000000
0.5000000000 0.5000000000 0.5000000000
end kpoints

begin projections
Ga: p
As: p
end projections

begin unit_cell_cart
ang
-2.8400000000 0.0000000000 2.8400000000
0.0000000000 2.8400000000 2.8400000000
-2.8400000000 2.8400000000 0.0000000000
end unit_cell_cart