diff --git a/nemo/collections/vlm/mllama/model/language.py b/nemo/collections/vlm/mllama/model/language.py
index bec3ec526f6e..3edc6706defb 100644
--- a/nemo/collections/vlm/mllama/model/language.py
+++ b/nemo/collections/vlm/mllama/model/language.py
@@ -390,7 +390,7 @@ def sharded_state_dict(
         layer_prefix = f'{prefix}layers.'
         num_layers = self.config.num_layers
         for layer in self.layers:
-            offset = layer._get_layer_offset()
+            offset = layer._get_layer_offset(layer.config)
             global_layer_offset = layer.layer_number - 1  # self.layer_number starts at 1
             state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.'  # module list index in TransformerBlock  # pylint: disable=line-too-long
             sharded_prefix = layer_prefix
@@ -403,7 +403,7 @@ def sharded_state_dict(
         for xlayer in self.xattn_layers:
             if isinstance(xlayer, DummyCrossAttentionTransformerLayer):
                 continue
-            offset = xlayer._get_layer_offset()
+            offset = xlayer._get_layer_offset(xlayer.config)
             global_layer_offset = xlayer.layer_number - 1
             state_dict_prefix = f'{xlayer_prefix}{global_layer_offset - offset}.'  # module list index in TransformerBlock  # pylint: disable=line-too-long
             sharded_prefix = f'{xlayer_prefix}{global_layer_offset}.'
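
Both hunks make the same one-line fix: `TransformerLayer._get_layer_offset` now expects the layer's config as an argument rather than reading it from the instance, so the self-attention and cross-attention loops pass `layer.config` / `xlayer.config` explicitly. The surrounding arithmetic is what the offset feeds into: `global_layer_offset` is the layer's position in the full model, and subtracting `offset` (the number of layers placed on earlier pipeline stages) gives the module-list index within this stage's `TransformerBlock`. Below is a minimal sketch of that arithmetic, assuming a uniform layer split across pipeline stages and mirroring the cross-attention hunk's prefixes; `StubConfig`, `get_layer_offset`, and `prefixes` are hypothetical stand-ins for illustration, not NeMo or Megatron-core APIs:

```python
from dataclasses import dataclass


@dataclass
class StubConfig:
    num_layers: int                     # total transformer layers in the model
    pipeline_model_parallel_size: int   # number of pipeline stages


def get_layer_offset(config: StubConfig, pipeline_rank: int) -> int:
    """Layers placed on earlier pipeline stages, assuming a uniform split."""
    return pipeline_rank * (config.num_layers // config.pipeline_model_parallel_size)


def prefixes(layer_number: int, config: StubConfig, pipeline_rank: int,
             layer_prefix: str = "decoder.layers."):
    offset = get_layer_offset(config, pipeline_rank)
    global_layer_offset = layer_number - 1       # layer_number starts at 1
    local_index = global_layer_offset - offset   # module list index in this stage's block
    state_dict_prefix = f"{layer_prefix}{local_index}."
    sharded_prefix = f"{layer_prefix}{global_layer_offset}."
    return state_dict_prefix, sharded_prefix


# Layer with layer_number=6 in an 8-layer model on stage 1 of a 2-way pipeline:
# module index 1 locally, but layer 5 globally.
cfg = StubConfig(num_layers=8, pipeline_model_parallel_size=2)
print(prefixes(6, cfg, pipeline_rank=1))  # ('decoder.layers.1.', 'decoder.layers.5.')
```

The split matters because checkpoint keys stay addressed by the global layer index (`sharded_prefix`) while each pipeline stage's in-memory module list remains zero-based (`state_dict_prefix`), which is what allows the same distributed checkpoint to be loaded under a different pipeline-parallel size.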