diff --git a/src/aiida_wannier90_workflows/utils/pseudo/__init__.py b/src/aiida_wannier90_workflows/utils/pseudo/__init__.py index 0309982..239bc17 100644 --- a/src/aiida_wannier90_workflows/utils/pseudo/__init__.py +++ b/src/aiida_wannier90_workflows/utils/pseudo/__init__.py @@ -209,6 +209,7 @@ def get_wannier_number_of_bands( factor=1.2, only_valence=False, spin_polarized=False, + spin_non_collinear: bool = False, spin_orbit_coupling: bool = False, ): """Estimate number of bands for a Wannier90 calculation. @@ -220,6 +221,8 @@ def get_wannier_number_of_bands( :type only_valence: bool :param spin_polarized: magnetic calculation? :type spin_polarized: bool + :param spin_non_collinear: non-collinear or spin-orbit-coupling + :type spin_non_collinear: bool :param spin_orbit_coupling: spin orbit coupling calculation? :type spin_orbit_coupling: bool :return: number of bands for Wannier90 SCDM @@ -236,8 +239,10 @@ def get_wannier_number_of_bands( raise ValueError("Should use SOC pseudo for SOC calculation") num_electrons = get_number_of_electrons(structure, pseudos) - num_projections = get_number_of_projections(structure, pseudos, spin_orbit_coupling) - nspin = 2 if (spin_polarized or spin_orbit_coupling) else 1 + num_projections = get_number_of_projections( + structure, pseudos, spin_non_collinear, spin_orbit_coupling + ) + nspin = 2 if (spin_polarized or spin_non_collinear) else 1 # TODO check nospin, spin, soc # pylint: disable=fixme if only_valence: num_bands = int(0.5 * num_electrons * nspin) @@ -259,6 +264,7 @@ def get_wannier_number_of_bands_ext( factor=1.2, only_valence=False, spin_polarized=False, + spin_non_collinear: bool = False, spin_orbit_coupling: bool = False, ): """Estimate number of bands for a Wannier90 calculation. @@ -272,6 +278,8 @@ def get_wannier_number_of_bands_ext( :type only_valence: bool :param spin_polarized: magnetic calculation? :type spin_polarized: bool + :param spin_non_collinear: non-collinear or spin-orbit-coupling + :type spin_non_collinear: bool :param spin_orbit_coupling: spin orbit coupling calculation? :type spin_orbit_coupling: bool :return: number of bands for Wannier90 SCDM @@ -289,9 +297,9 @@ def get_wannier_number_of_bands_ext( num_electrons = get_number_of_electrons(structure, pseudos) num_projections = get_number_of_projections_ext( - structure, external_projectors, spin_orbit_coupling + structure, external_projectors, spin_non_collinear, spin_orbit_coupling ) - nspin = 2 if (spin_polarized or spin_orbit_coupling) else 1 + nspin = 2 if (spin_polarized or spin_non_collinear) else 1 # TODO check nospin, spin, soc # pylint: disable=fixme if only_valence: num_bands = int(0.5 * num_electrons * nspin) @@ -309,6 +317,7 @@ def get_wannier_number_of_bands_ext( def get_number_of_projections( structure: orm.StructureData, pseudos: ty.Mapping[str, orm.UpfData], + spin_non_collinear: bool, spin_orbit_coupling: ty.Optional[bool] = None, ) -> int: """Get number of projections for the structure with the given pseudopotential files. @@ -320,6 +329,8 @@ def get_number_of_projections( :type structure: aiida.orm.StructureData :param pseudos: a dictionary contains orm.UpfData of the structure :type pseudos: dict + :param spin_non_collinear: non-collinear or spin-orbit-coupling + :type spin_non_collinear: bool :return: number of projections :rtype: int """ @@ -357,12 +368,14 @@ def get_number_of_projections( upf = pseudos[kind] nprojs = get_number_of_projections_from_upf(upf) soc = is_soc_pseudo(get_upf_content(pseudos[kind])) - if spin_orbit_coupling and not soc: - # For SOC calculation with non-SOC pseudo, QE will generate + if spin_non_collinear and not soc: + # For magnetic calculation with non-SOC pseudo, QE will generate # 2 PSWFCs from each one PSWFC in the pseudo + # For collinear-magnetic calculation, spin up and down will calc + # seperately, so nprojs do not times 2 nprojs *= 2 - elif not spin_orbit_coupling and soc: - # For non-SOC calculation with SOC pseudo, QE will average + elif not spin_non_collinear and soc: + # For non-magnetic calculation with SOC pseudo, QE will average # the 2 PSWFCs into one nprojs //= 2 tot_nprojs += nprojs * composition[kind] @@ -373,6 +386,7 @@ def get_number_of_projections( def get_number_of_projections_ext( structure: orm.StructureData, external_projectors: dict, + spin_non_collinear: bool, spin_orbit_coupling: bool = False, ) -> int: """Get number of projections for the structure with the given projector dict. @@ -381,6 +395,8 @@ def get_number_of_projections_ext( :type structure: aiida.orm.StructureData :param projectors: a dictionary contains projector list of the structure :type pseudos: dict + :param spin_non_collinear: non-collinear or spin-orbit-coupling + :type spin_non_collinear: bool :return: number of projections :rtype: int """ @@ -407,6 +423,8 @@ def get_number_of_projections_ext( nprojs += round(2 * orb["j"]) + 1 else: nprojs += 2 * orb["l"] + 1 + if spin_non_collinear: + nprojs *= 2 tot_nprojs += nprojs * composition[kind] return tot_nprojs diff --git a/src/aiida_wannier90_workflows/workflows/base/wannier90.py b/src/aiida_wannier90_workflows/workflows/base/wannier90.py index 71c48e8..f8f49ea 100644 --- a/src/aiida_wannier90_workflows/workflows/base/wannier90.py +++ b/src/aiida_wannier90_workflows/workflows/base/wannier90.py @@ -330,11 +330,13 @@ def get_builder_from_protocol( factor=meta_parameters["num_bands_factor"], only_valence=only_valence, spin_polarized=spin_polarized, + spin_non_collinear=spin_non_collinear, spin_orbit_coupling=spin_orbit_coupling, ) num_projs = get_number_of_projections_ext( structure=structure, external_projectors=external_projectors, + spin_non_collinear=spin_non_collinear, spin_orbit_coupling=spin_orbit_coupling, ) else: @@ -344,11 +346,13 @@ def get_builder_from_protocol( factor=meta_parameters["num_bands_factor"], only_valence=only_valence, spin_polarized=spin_polarized, + spin_non_collinear=spin_non_collinear, spin_orbit_coupling=spin_orbit_coupling, ) num_projs = get_number_of_projections( structure=structure, pseudos=pseudos, + spin_non_collinear=spin_non_collinear, spin_orbit_coupling=spin_orbit_coupling, ) diff --git a/src/aiida_wannier90_workflows/workflows/wannier90.py b/src/aiida_wannier90_workflows/workflows/wannier90.py index 4fb3211..ef1d92e 100644 --- a/src/aiida_wannier90_workflows/workflows/wannier90.py +++ b/src/aiida_wannier90_workflows/workflows/wannier90.py @@ -995,11 +995,22 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements check_num_projs = True if self.should_run_scf(): pseudos = self.inputs["scf"]["pw"]["pseudos"] + spin_orbit_coupling = ( + self.inputs["scf"]["pw"]["parameters"] + .get_dict()["SYSTEM"] + .get("SYSTEM", False) + ) elif self.should_run_nscf(): pseudos = self.inputs["nscf"]["pw"]["pseudos"] + spin_orbit_coupling = ( + self.inputs["nscf"]["pw"]["parameters"] + .get_dict()["SYSTEM"] + .get("SYSTEM", False) + ) else: check_num_projs = False - pseudos = None # to avoid pylint errors + pseudos = None + spin_orbit_coupling = None if check_num_projs: args = { "structure": self.ctx.current_structure, @@ -1014,9 +1025,11 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements params = self.ctx.workchain_wannier90.inputs["wannier90"][ "parameters" ].get_dict() - spin_orbit_coupling = params.get("spinors", False) + spin_non_collinear = params.get("spinors", False) number_of_projections = get_number_of_projections( - **args, spin_orbit_coupling=spin_orbit_coupling + **args, + spin_non_collinear=spin_non_collinear, + spin_orbit_coupling=spin_orbit_coupling, ) if number_of_projections != num_proj: self.report(