Updated dynamic freezing to work with MegatronGPT
trias702 committed Sep 22, 2023
1 parent 2a6aa8a commit ee25d49
Showing 1 changed file with 20 additions and 3 deletions.
nemo/core/classes/modelPT.py (20 additions & 3 deletions)
@@ -1662,6 +1662,14 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
         if hasattr(self, '_freeze_cfg') and self._freeze_cfg is not None:
             if self.training and hasattr(self, "trainer") and self.trainer is not None:
                 num_updates = self.trainer.global_step + 1
+                obj_to_set = None
+                # bit of a hack to ensure compatibility with MegatronGPT modules and FP16Module wrapper
+                if hasattr(self, 'model') and hasattr(self.model, 'module'):
+                    obj_to_set = self.model.module
+                elif hasattr(self, 'model'):
+                    obj_to_set = self.model
+                else:
+                    obj_to_set = self
 
                 for ml, m_steps in self._freeze_cfg['modules'].items():
                     # we could do hasattr check here, but it's too expensive for each step
@@ -1673,11 +1681,20 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
                     else:
                         should_freeze = num_updates <= m_steps or m_steps == -1
                     if should_freeze and not self._freeze_cfg['is_frozen'][ml]:
-                        getattr(self, ml).freeze()
-                        getattr(self, ml).train()
+                        if hasattr(getattr(obj_to_set, ml), 'freeze'):
+                            getattr(obj_to_set, ml).freeze()
+                        else:
+                            for param in getattr(obj_to_set, ml).parameters():
+                                param.requires_grad = False
+                        getattr(obj_to_set, ml).train()
                         self._freeze_cfg['is_frozen'][ml] = True
                     elif not should_freeze and self._freeze_cfg['is_frozen'][ml]:
-                        getattr(self, ml).unfreeze()
+                        if hasattr(getattr(obj_to_set, ml), 'unfreeze'):
+                            getattr(obj_to_set, ml).unfreeze()
+                        else:
+                            for param in getattr(obj_to_set, ml).parameters():
+                                param.requires_grad = True
+                        getattr(obj_to_set, ml).train()
                         self._freeze_cfg['is_frozen'][ml] = False
 
     def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int = 0) -> None:
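
For context on the obj_to_set lookup added above: MegatronGPT models keep their submodules on self.model, and mixed-precision training can additionally wrap that model in an FP16Module-style wrapper that exposes the real model as .module, so attributes named in the freeze config are not reachable directly from self. The toy classes below are illustrative stand-ins (not the actual NeMo or Megatron wrappers) sketching why the lookup walks self.model.module -> self.model -> self, and why a requires_grad fallback is needed when the resolved submodule has no freeze() method:

import torch.nn as nn

# Toy stand-ins for illustration only; the real classes are NeMo's ModelPT,
# MegatronGPTModel, and Megatron's FP16 wrapper.
class ToyFP16Wrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module  # the wrapped model lives one attribute level down

class ToyMegatronModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)  # submodule name referenced by the freeze config

class ToyLightningModule:
    def __init__(self):
        # mirrors the MegatronGPT case: submodules hang off self.model.module
        self.model = ToyFP16Wrapper(ToyMegatronModel())

pl_module = ToyLightningModule()
ml = 'encoder'  # hypothetical entry from _freeze_cfg['modules']

# same resolution order as the diff: model.module -> model -> self
if hasattr(pl_module, 'model') and hasattr(pl_module.model, 'module'):
    obj_to_set = pl_module.model.module
elif hasattr(pl_module, 'model'):
    obj_to_set = pl_module.model
else:
    obj_to_set = pl_module

# a plain nn.Module has no freeze(), so fall back to requires_grad, as in the diff
target = getattr(obj_to_set, ml)
if hasattr(target, 'freeze'):
    target.freeze()
else:
    for param in target.parameters():
        param.requires_grad = False
print(all(not p.requires_grad for p in target.parameters()))  # True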
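The per-step bookkeeping around that change is untouched by this commit: a module is meant to stay frozen while num_updates <= m_steps (or for the whole run when m_steps is -1), and freeze()/unfreeze() only fire when the is_frozen flag actually flips. A minimal self-contained sketch of that visible logic, using a hypothetical modules mapping (the names and step counts are made up for illustration):

# Sketch of the freeze/unfreeze bookkeeping from on_train_batch_start;
# the freeze_cfg contents below are hypothetical examples.
freeze_cfg = {
    'modules': {'encoder': 3, 'adapter': -1},
    'is_frozen': {'encoder': False, 'adapter': False},
}

def on_train_batch_start(num_updates: int) -> None:
    for ml, m_steps in freeze_cfg['modules'].items():
        # integer form of the rule shown in the diff context
        should_freeze = num_updates <= m_steps or m_steps == -1
        if should_freeze and not freeze_cfg['is_frozen'][ml]:
            print(f"update {num_updates}: freeze {ml}")
            freeze_cfg['is_frozen'][ml] = True
        elif not should_freeze and freeze_cfg['is_frozen'][ml]:
            print(f"update {num_updates}: unfreeze {ml}")
            freeze_cfg['is_frozen'][ml] = False

for step in range(1, 6):
    on_train_batch_start(step)
# freezes both at update 1, unfreezes 'encoder' at update 4, 'adapter' stays frozen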
