diff --git a/aiida_wannier90/io/_write_win.py b/aiida_wannier90/io/_write_win.py index 99022f8..1863fd4 100644 --- a/aiida_wannier90/io/_write_win.py +++ b/aiida_wannier90/io/_write_win.py @@ -13,6 +13,7 @@ from aiida.common import InputValidationError from ..utils import conv_to_fortran_withlists +from ..orbitals import generate_projections from ._group_list import list_to_grouped_string __all__ = ('write_win', ) @@ -46,8 +47,9 @@ def write_win( # pylint: disable=too-many-arguments :type kpoint_path: aiida.orm.nodes.data.dict.Dict :param projections: Orbitals used for the projections. Can be specified either as AiiDA class :py:class:`OrbitalData `, - or as a list of strings specifying the projections in Wannier90's format. - :type projections: aiida.orm.nodes.data.orbital.OrbitalData, aiida.orm.nodes.data.list.List[str] + or as a list of strings specifying the projections in Wannier90's format, + or as a list of dict in the format of the argument of `generate_projections`. + :type projections: aiida.orm.nodes.data.orbital.OrbitalData, aiida.orm.nodes.data.list.List[str], aiida.orm.nodes.data.list.List[dict] :param random_projections: If class :py:class:`OrbitalData ` is used for projections, enables random projections completion :type random_projections: aiida.orm.nodes.data.bool.Bool @@ -65,7 +67,7 @@ def write_win( # pylint: disable=too-many-arguments ) -def _create_win_string( # pylint: disable=too-many-branches,missing-function-docstring +def _create_win_string( # pylint: disable=too-many-branches,missing-function-docstring # noqa: MC0001 parameters, kpoints, structure=None, @@ -74,7 +76,7 @@ def _create_win_string( # pylint: disable=too-many-branches,missing-function-doc random_projections=False, ): from aiida.plugins import DataFactory - from aiida.orm import List + from aiida.orm import List, Dict # prepare the main input text input_file_lines = [] @@ -106,7 +108,31 @@ def _create_win_string( # pylint: disable=too-many-branches,missing-function-doc 'random_projections cannot be True if with List-type projections.' 'Instead, use "random" string as first element of the List.' ) - block_inputs['projections'] = projections.get_list() + lst = projections.get_list() + if all(isinstance(x, str) for x in lst): + block_inputs['projections'] = lst + elif all(isinstance(x, dict) for x in lst): + orbital_data = generate_projections(lst, structure=structure) + block_inputs['projections'] = _format_all_projections( + orbital_data, random_projections=True + ) + else: + raise InputValidationError( + 'Projections List contains wrong elements.' + 'They need to be either all strings or all dicts.' + ) + + elif isinstance(projections, Dict): + if random_projections: + raise InputValidationError( + 'random_projections cannot be True if with Dict-type projections.' + ) + orbital_data = generate_projections( + projections.get_dict(), structure=structure + ) + block_inputs['projections'] = _format_all_projections( + orbital_data, random_projections=True + ) else: block_inputs['projections'] = _format_all_projections( projections, random_projections=True diff --git a/tests/io/test_win_writer.py b/tests/io/test_win_writer.py index d7a83b5..81a9e77 100644 --- a/tests/io/test_win_writer.py +++ b/tests/io/test_win_writer.py @@ -10,6 +10,7 @@ import pytest from aiida.common.exceptions import InputValidationError +from aiida.orm import Dict, List def test_create_win_string(generate_win_params_gaas, file_regression): @@ -22,6 +23,58 @@ def test_create_win_string(generate_win_params_gaas, file_regression): ) +def test_create_win_string_projections_list_of_str( + generate_win_params_gaas, file_regression +): + """Test _write_win for parameter projections using a List of str.""" + from aiida_wannier90.io._write_win import _create_win_string + + gaas_params = generate_win_params_gaas() + gaas_params['projections'] = List(list=["Ga: p", "As: p"]) + file_regression.check( + _create_win_string(**gaas_params), encoding='utf-8', extension='.win' + ) + + +def test_create_win_string_projections_list_of_dict( + generate_win_params_gaas, file_regression +): + """Test _write_win for parameter projections using a List of dict.""" + from aiida_wannier90.io._write_win import _create_win_string + + gaas_params = generate_win_params_gaas() + gaas_params['projections'] = List( + list=[{ + "kind_name": "Ga", + "ang_mtm_name": ["p"] + }, { + "kind_name": "As", + "ang_mtm_name": ["p"] + }] + ) + file_regression.check( + _create_win_string(**gaas_params), encoding='utf-8', extension='.win' + ) + + +def test_create_win_string_projections_dict( + generate_win_params_gaas, file_regression +): + """Test _write_win for parameter projections using a List of str.""" + from aiida_wannier90.io._write_win import _create_win_string + + gaas_params = generate_win_params_gaas() + gaas_params['projections'] = Dict( + dict={ + "kind_name": "Ga", + "ang_mtm_name": ["p"] + } + ) + file_regression.check( + _create_win_string(**gaas_params), encoding='utf-8', extension='.win' + ) + + def test_exclude_bands(generate_kpoints_mesh, file_regression): """Test _write_win for parameter exclude_bands""" from aiida_wannier90.io._write_win import _create_win_string diff --git a/tests/io/test_win_writer/test_create_win_string_projections_dict.win b/tests/io/test_win_writer/test_create_win_string_projections_dict.win new file mode 100644 index 0000000..4622a5d --- /dev/null +++ b/tests/io/test_win_writer/test_create_win_string_projections_dict.win @@ -0,0 +1,48 @@ +mp_grid = 2, 2, 2 +num_iter = 12 +num_wann = 4 +wvfn_formatted = .true. + +begin atoms_cart +ang +Ga 0.0000000000 0.0000000000 0.0000000000 +As -1.4200000000 1.4200000000 1.4200000000 +end atoms_cart + +begin kpoint_path +G 0.0 0.0 0.0 X 0.5 0.0 0.5 +X 0.5 0.0 0.5 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 K 0.375 0.375 0.75 +K 0.375 0.375 0.75 G 0.0 0.0 0.0 +G 0.0 0.0 0.0 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 U 0.625 0.25 0.625 +U 0.625 0.25 0.625 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 K 0.375 0.375 0.75 +U 0.625 0.25 0.625 X 0.5 0.0 0.5 +end kpoint_path + +begin kpoints + 0.0000000000 0.0000000000 0.0000000000 + 0.0000000000 0.0000000000 0.5000000000 + 0.0000000000 0.5000000000 0.0000000000 + 0.0000000000 0.5000000000 0.5000000000 + 0.5000000000 0.0000000000 0.0000000000 + 0.5000000000 0.0000000000 0.5000000000 + 0.5000000000 0.5000000000 0.0000000000 + 0.5000000000 0.5000000000 0.5000000000 +end kpoints + +begin projections +random +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=1:::r=1: +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=2:::r=1: +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=3:::r=1: +end projections + +begin unit_cell_cart +ang + -2.8400000000 0.0000000000 2.8400000000 + 0.0000000000 2.8400000000 2.8400000000 + -2.8400000000 2.8400000000 0.0000000000 +end unit_cell_cart diff --git a/tests/io/test_win_writer/test_create_win_string_projections_list_of_dict.win b/tests/io/test_win_writer/test_create_win_string_projections_list_of_dict.win new file mode 100644 index 0000000..c8ab0eb --- /dev/null +++ b/tests/io/test_win_writer/test_create_win_string_projections_list_of_dict.win @@ -0,0 +1,51 @@ +mp_grid = 2, 2, 2 +num_iter = 12 +num_wann = 4 +wvfn_formatted = .true. + +begin atoms_cart +ang +Ga 0.0000000000 0.0000000000 0.0000000000 +As -1.4200000000 1.4200000000 1.4200000000 +end atoms_cart + +begin kpoint_path +G 0.0 0.0 0.0 X 0.5 0.0 0.5 +X 0.5 0.0 0.5 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 K 0.375 0.375 0.75 +K 0.375 0.375 0.75 G 0.0 0.0 0.0 +G 0.0 0.0 0.0 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 U 0.625 0.25 0.625 +U 0.625 0.25 0.625 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 K 0.375 0.375 0.75 +U 0.625 0.25 0.625 X 0.5 0.0 0.5 +end kpoint_path + +begin kpoints + 0.0000000000 0.0000000000 0.0000000000 + 0.0000000000 0.0000000000 0.5000000000 + 0.0000000000 0.5000000000 0.0000000000 + 0.0000000000 0.5000000000 0.5000000000 + 0.5000000000 0.0000000000 0.0000000000 + 0.5000000000 0.0000000000 0.5000000000 + 0.5000000000 0.5000000000 0.0000000000 + 0.5000000000 0.5000000000 0.5000000000 +end kpoints + +begin projections +random +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=1:::r=1: +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=2:::r=1: +c=0.0000000000,0.0000000000,0.0000000000:l=1,mr=3:::r=1: +c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=1:::r=1: +c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=2:::r=1: +c=-1.4200000000,1.4200000000,1.4200000000:l=1,mr=3:::r=1: +end projections + +begin unit_cell_cart +ang + -2.8400000000 0.0000000000 2.8400000000 + 0.0000000000 2.8400000000 2.8400000000 + -2.8400000000 2.8400000000 0.0000000000 +end unit_cell_cart diff --git a/tests/io/test_win_writer/test_create_win_string_projections_list_of_str.win b/tests/io/test_win_writer/test_create_win_string_projections_list_of_str.win new file mode 100644 index 0000000..32bac21 --- /dev/null +++ b/tests/io/test_win_writer/test_create_win_string_projections_list_of_str.win @@ -0,0 +1,46 @@ +mp_grid = 2, 2, 2 +num_iter = 12 +num_wann = 4 +wvfn_formatted = .true. + +begin atoms_cart +ang +Ga 0.0000000000 0.0000000000 0.0000000000 +As -1.4200000000 1.4200000000 1.4200000000 +end atoms_cart + +begin kpoint_path +G 0.0 0.0 0.0 X 0.5 0.0 0.5 +X 0.5 0.0 0.5 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 K 0.375 0.375 0.75 +K 0.375 0.375 0.75 G 0.0 0.0 0.0 +G 0.0 0.0 0.0 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 U 0.625 0.25 0.625 +U 0.625 0.25 0.625 W 0.5 0.25 0.75 +W 0.5 0.25 0.75 L 0.5 0.5 0.5 +L 0.5 0.5 0.5 K 0.375 0.375 0.75 +U 0.625 0.25 0.625 X 0.5 0.0 0.5 +end kpoint_path + +begin kpoints + 0.0000000000 0.0000000000 0.0000000000 + 0.0000000000 0.0000000000 0.5000000000 + 0.0000000000 0.5000000000 0.0000000000 + 0.0000000000 0.5000000000 0.5000000000 + 0.5000000000 0.0000000000 0.0000000000 + 0.5000000000 0.0000000000 0.5000000000 + 0.5000000000 0.5000000000 0.0000000000 + 0.5000000000 0.5000000000 0.5000000000 +end kpoints + +begin projections +Ga: p +As: p +end projections + +begin unit_cell_cart +ang + -2.8400000000 0.0000000000 2.8400000000 + 0.0000000000 2.8400000000 2.8400000000 + -2.8400000000 2.8400000000 0.0000000000 +end unit_cell_cart