diff --git a/src/aiida_quantumespresso_hp/workflows/hubbard.py b/src/aiida_quantumespresso_hp/workflows/hubbard.py index 529c982..ab6f9ff 100644 --- a/src/aiida_quantumespresso_hp/workflows/hubbard.py +++ b/src/aiida_quantumespresso_hp/workflows/hubbard.py @@ -615,6 +615,8 @@ def inspect_hp(self): def check_convergence(self): """Check the convergence of the Hubbard parameters.""" + from aiida_quantumespresso.utils.hubbard import is_intersite_hubbard + workchain = self.ctx.workchains_hp[-1] # We store in memory the parameters before relabelling to make the comparison easier. @@ -631,16 +633,17 @@ def check_convergence(self): # We check if new types were created, in which case we relabel the `HubbardStructureData` self.ctx.current_hubbard_structure = workchain.outputs.hubbard_structure - for site in workchain.outputs.hubbard.dict.sites: - if not site['type'] == site['new_type']: - self.report('new types have been detected: relabeling the structure and starting new iteration.') - result = structure_relabel_kinds( - self.ctx.current_hubbard_structure, workchain.outputs.hubbard, self.ctx.current_magnetic_moments - ) - self.ctx.current_hubbard_structure = result['hubbard_structure'] - if self.ctx.current_magnetic_moments is not None: - self.ctx.current_magnetic_moments = result['starting_magnetization'] - break + if not is_intersite_hubbard(workchain.outputs.hubbard_structure.hubbard): + for site in workchain.outputs.hubbard.dict.sites: + if not site['type'] == site['new_type']: + self.report('new types have been detected: relabeling the structure and starting new iteration.') + result = structure_relabel_kinds( + self.ctx.current_hubbard_structure, workchain.outputs.hubbard, self.ctx.current_magnetic_moments + ) + self.ctx.current_hubbard_structure = result['hubbard_structure'] + if self.ctx.current_magnetic_moments is not None: + self.ctx.current_magnetic_moments = result['starting_magnetization'] + break if not len(ref_params) == len(new_params): self.report('The new and old Hubbard parameters have different lenghts. Assuming to be at the first cycle.') diff --git a/tests/workflows/test_hubbard.py b/tests/workflows/test_hubbard.py index 41482d4..feda8fe 100644 --- a/tests/workflows/test_hubbard.py +++ b/tests/workflows/test_hubbard.py @@ -329,8 +329,6 @@ def test_not_converged_check_convergence( process.setup() - # Mocking current (i.e. "old") and "new" HubbardStructureData, - # containing different Hubbard parameters process.ctx.current_hubbard_structure = generate_hubbard_structure() process.ctx.workchains_hp = [generate_hp_workchain_node(u_value=5.0)] @@ -354,19 +352,26 @@ def test_relabel_check_convergence( process.setup() - # Mocking current (i.e. "old") and "new" HubbardStructureData, - # containing different Hubbard parameters - process.ctx.current_hubbard_structure = generate_hubbard_structure() - process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=100)] - + current_hubbard_structure = generate_hubbard_structure(u_value=1, only_u=True) + process.ctx.current_hubbard_structure = current_hubbard_structure + process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=100, only_u=True)] process.check_convergence() assert not process.ctx.is_converged + assert process.ctx.current_hubbard_structure.get_kind_names() != current_hubbard_structure.get_kind_names() - process.ctx.current_hubbard_structure = generate_hubbard_structure(u_value=99.99) - process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=100)] + current_hubbard_structure = generate_hubbard_structure(u_value=99.99, only_u=True) + process.ctx.current_hubbard_structure = current_hubbard_structure + process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=100, only_u=True)] + process.check_convergence() + assert process.ctx.is_converged + assert process.ctx.current_hubbard_structure.get_kind_names() != current_hubbard_structure.get_kind_names() + current_hubbard_structure = generate_hubbard_structure(u_value=99.99) + process.ctx.current_hubbard_structure = current_hubbard_structure + process.ctx.workchains_hp = [generate_hp_workchain_node(relabel=True, u_value=100)] process.check_convergence() assert process.ctx.is_converged + assert process.ctx.current_hubbard_structure.get_kind_names() == current_hubbard_structure.get_kind_names() @pytest.mark.usefixtures('aiida_profile')