From 3cbb164dd30d1ccf3918d9d04227378be17404b1 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Date: Tue, 25 Jun 2024 08:32:53 -0700
Subject: [PATCH] mcore distOpt restore fix (#9421)

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/nlp/parts/nlp_overrides.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 0555776457a5..2fdb1906c31f 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -444,6 +444,9 @@ def _check_param_groups_mismatch(self, checkpoint_path: Union[str, Path], sharde
             bool: True if the number of param groups does not match
         """
         common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path)
+        # @akoumparouli: check if it contains an mcore dist opt
+        if common_state_dict.get('optimizer_states', [{}])[0].get('param_groups', None) is None:
+            return False
         model_param_groups = self._get_param_group(common_state_dict)
         checkpoint_param_groups = self._get_param_group(sharded_state_dict)
         return len(model_param_groups) != len(checkpoint_param_groups)