diff --git a/src/aiida_quantumespresso_hp/workflows/hubbard.py b/src/aiida_quantumespresso_hp/workflows/hubbard.py index 7673964..d517163 100644 --- a/src/aiida_quantumespresso_hp/workflows/hubbard.py +++ b/src/aiida_quantumespresso_hp/workflows/hubbard.py @@ -421,6 +421,22 @@ def get_pseudos(self) -> dict: return results + def relabel_hubbard_structure(self, workchain) -> None: + """Relabel the Hubbard structure if new types have been detected.""" + from aiida_quantumespresso.utils.hubbard import is_intersite_hubbard + + if not is_intersite_hubbard(workchain.outputs.hubbard_structure.hubbard): + for site in workchain.outputs.hubbard.dict.sites: + if not site['type'] == site['new_type']: + result = structure_relabel_kinds( + self.ctx.current_hubbard_structure, workchain.outputs.hubbard, self.ctx.current_magnetic_moments + ) + self.ctx.current_hubbard_structure = result['hubbard_structure'] + if self.ctx.current_magnetic_moments is not None: + self.ctx.current_magnetic_moments = result['starting_magnetization'] + self.report('new types have been detected: relabeling the structure.') + return + def run_relax(self): """Run the PwRelaxWorkChain to run a relax PwCalculation.""" inputs = self.get_inputs(PwRelaxWorkChain, 'relax') @@ -578,14 +594,14 @@ def inspect_hp(self): if not self.should_check_convergence(): self.ctx.current_hubbard_structure = workchain.outputs.hubbard_structure + self.relabel_hubbard_structure(workchain) + if not self.inputs.meta_convergence: self.report('meta convergence is switched off, so not checking convergence of Hubbard parameters.') self.ctx.is_converged = True def check_convergence(self): """Check the convergence of the Hubbard parameters.""" - from aiida_quantumespresso.utils.hubbard import is_intersite_hubbard - workchain = self.ctx.workchains_hp[-1] # We store in memory the parameters before relabelling to make the comparison easier. @@ -601,18 +617,7 @@ def check_convergence(self): # We check if new types were created, in which case we relabel the `HubbardStructureData` self.ctx.current_hubbard_structure = workchain.outputs.hubbard_structure - - if not is_intersite_hubbard(workchain.outputs.hubbard_structure.hubbard): - for site in workchain.outputs.hubbard.dict.sites: - if not site['type'] == site['new_type']: - self.report('new types have been detected: relabeling the structure and starting new iteration.') - result = structure_relabel_kinds( - self.ctx.current_hubbard_structure, workchain.outputs.hubbard, self.ctx.current_magnetic_moments - ) - self.ctx.current_hubbard_structure = result['hubbard_structure'] - if self.ctx.current_magnetic_moments is not None: - self.ctx.current_magnetic_moments = result['starting_magnetization'] - break + self.relabel_hubbard_structure(workchain) if not len(ref_params) == len(new_params): self.report('The new and old Hubbard parameters have different lenghts. Assuming to be at the first cycle.') diff --git a/tests/workflows/test_hubbard.py b/tests/workflows/test_hubbard.py index 2a06584..c24802c 100644 --- a/tests/workflows/test_hubbard.py +++ b/tests/workflows/test_hubbard.py @@ -227,6 +227,41 @@ def test_skip_relax_iterations(generate_workchain_hubbard, generate_inputs_hubba assert process.should_check_convergence() +@pytest.mark.usefixtures('aiida_profile') +def test_skip_relax_iterations_relabeling( + generate_workchain_hubbard, generate_inputs_hubbard, generate_hp_workchain_node, generate_hubbard_structure +): + """Test `SelfConsistentHubbardWorkChain` when skipping the first relax iterations and relabeling is needed.""" + from aiida.orm import Bool, Int + + inputs = generate_inputs_hubbard() + inputs['skip_relax_iterations'] = Int(1) + inputs['meta_convergence'] = Bool(True) + process = generate_workchain_hubbard(inputs=inputs) + process.setup() + + current_hubbard_structure = generate_hubbard_structure(u_value=1, only_u=True) + process.current_hubbard_structure = current_hubbard_structure + # 1 + process.update_iteration() + assert process.ctx.skip_relax_iterations == 1 + assert process.ctx.iteration == 1 + assert not process.should_run_relax() + assert not process.should_check_convergence() + process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=1, only_u=True)] + process.inspect_hp() + assert process.ctx.current_hubbard_structure.get_kind_names( + ) != process.ctx.workchains_hp[-1].outputs.hubbard_structure.get_kind_names() + # 2 + process.update_iteration() + assert process.should_run_relax() + assert process.should_check_convergence() + # 3 + process.update_iteration() + assert process.should_run_relax() + assert process.should_check_convergence() + + @pytest.mark.usefixtures('aiida_profile') def test_relax_frequency(generate_workchain_hubbard, generate_inputs_hubbard): """Test `SelfConsistentHubbardWorkChain` when `relax_frequency` is different from 1."""