Commit bece47f

Apply isort and black reformatting
Signed-off-by: huvunvidia <[email protected]>
huvunvidia committed Jan 2, 2025
1 parent db9e835 commit bece47f
Showing 4 changed files with 63 additions and 17 deletions.
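
This commit is mechanical: isort normalizes import ordering and black rewraps long statements. A minimal sketch of reproducing the reformatting programmatically, assuming isort >= 5 and black are installed; the 119-character line length is an assumption about the project's configuration, not verified against NeMo's pyproject.toml:

```python
import black
import isort

SRC = (
    "from nemo.collections.llm.t5.model import T5Config, T5Config220M, "
    "T5Config3B, T5Config11B, T5Model, t5_data_step, t5_forward_step\n"
)

# isort normalizes the import ordering (the exact order depends on project config).
sorted_src = isort.code(SRC)

# black then enforces line length and trailing commas.
formatted = black.format_str(sorted_src, mode=black.Mode(line_length=119))
print(formatted)
```

Once black inserts a trailing comma into a call or import that it had to split, its "magic trailing comma" keeps the construct exploded one element per line on later runs, which is why the `apply_transforms` calls below gained trailing commas.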
10 changes: 9 additions & 1 deletion nemo/collections/llm/__init__.py
@@ -134,7 +134,15 @@
from nemo.collections.llm.t5.data import MockDataModule as T5MockDataModule
from nemo.collections.llm.t5.data import PreTrainingDataModule as T5PreTrainingDataModule
from nemo.collections.llm.t5.data import SquadDataModule as T5SquadDataModule
from nemo.collections.llm.t5.model import T5Config, T5Config220M, T5Config3B, T5Config11B, T5Model, t5_data_step, t5_forward_step
from nemo.collections.llm.t5.model import (
T5Config,
T5Config3B,
T5Config11B,
T5Config220M,
T5Model,
t5_data_step,
t5_forward_step,
)

__all__ = [
"MockDataModule",
2 changes: 1 addition & 1 deletion nemo/collections/llm/t5/model/__init__.py
@@ -1,9 +1,9 @@
from nemo.collections.llm.t5.model.t5 import (
MaskedTokenLossReduction,
T5Config,
T5Config220M,
T5Config3B,
T5Config11B,
T5Config220M,
T5Model,
local_layer_spec,
t5_data_step,
66 changes: 52 additions & 14 deletions nemo/collections/llm/t5/model/t5.py
@@ -348,6 +348,7 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction:

return self._validation_loss_reduction


@io.model_importer(T5Model, "hf")
class HFT5Importer(io.ModelConnector["T5ForConditionalGeneration", T5Model]):
def init(self) -> T5Model:
@@ -394,11 +395,17 @@ def convert_state(self, source, target):
del mapping["lm_head.weight"]

return io.apply_transforms(
source,
target,
mapping=mapping,
transforms=[_import_encoder_qkv, _import_encoder_linear_fc1, _import_decoder_qkv, _import_decoder_kv, _import_decoder_linear_fc1],
state_dict_ignored_entries=['output_layer.weight']
source,
target,
mapping=mapping,
transforms=[
_import_encoder_qkv,
_import_encoder_linear_fc1,
_import_decoder_qkv,
_import_decoder_kv,
_import_decoder_linear_fc1,
],
state_dict_ignored_entries=['output_layer.weight'],
)

@property
Expand Down Expand Up @@ -433,7 +440,7 @@ def make_vocab_size_divisible_by(vocab_size):
position_embedding_type="relative",
relative_attention_num_buckets=source.relative_attention_num_buckets,
relative_attention_max_distance=source.relative_attention_max_distance,
activation_func=F.gelu,
add_bias_linear=False,
init_method_std=source.initializer_factor,
normalization="RMSNorm",
@@ -449,6 +456,7 @@

return output


@io.state_transform(
source_key=(
"encoder.block.*.layer.0.SelfAttention.q.weight",
@@ -481,6 +489,7 @@ def _import_encoder_qkv(ctx: io.TransformCTX, q, k, v):

return qkv_weights
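
The collapsed body of `_import_encoder_qkv` interleaves the separate Hugging Face q/k/v projection weights head by head into the fused `linear_qkv` layout that Megatron-style attention expects. A minimal sketch of that interleave under toy dimensions (all sizes here are illustrative, not T5's real shapes):

```python
import torch

head_num, head_size, hidden_size = 8, 64, 512  # toy dimensions
q = torch.randn(head_num * head_size, hidden_size)
k = torch.randn(head_num * head_size, hidden_size)
v = torch.randn(head_num * head_size, hidden_size)

# Group rows per head, then interleave as [q_0, k_0, v_0, q_1, k_1, v_1, ...].
q = q.view(head_num, head_size, hidden_size)
k = k.view(head_num, head_size, hidden_size)
v = v.view(head_num, head_size, hidden_size)
qkv = torch.stack((q, k, v), dim=1)  # [head_num, 3, head_size, hidden_size]
qkv_weights = qkv.reshape(head_num * 3 * head_size, hidden_size)
```

`_import_decoder_kv` below follows the same pattern with only k and v, which is why its result reshapes to `head_size * (2 * head_num)` rows.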


@io.state_transform(
source_key=(
"decoder.block.*.layer.0.SelfAttention.q.weight",
@@ -513,6 +522,7 @@ def _import_decoder_qkv(ctx: io.TransformCTX, q, k, v):

return qkv_weights


@io.state_transform(
source_key=(
"decoder.block.*.layer.1.EncDecAttention.k.weight",
@@ -541,16 +551,24 @@ def _import_decoder_kv(ctx: io.TransformCTX, k, v):
kv_weights = kv_weights.reshape([head_size * (2 * head_num), hidden_size])

return kv_weights



@io.state_transform(
source_key=("encoder.block.*.layer.1.DenseReluDense.wi_0.weight", "encoder.block.*.layer.1.DenseReluDense.wi_1.weight"),
source_key=(
"encoder.block.*.layer.1.DenseReluDense.wi_0.weight",
"encoder.block.*.layer.1.DenseReluDense.wi_1.weight",
),
target_key="encoder.layers.*.mlp.linear_fc1.weight",
)
def _import_encoder_linear_fc1(down, gate):
return torch.cat((down, gate), axis=0)
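
`_import_encoder_linear_fc1` fuses T5's gated MLP halves (`wi_0`, `wi_1`) into a single `linear_fc1` weight by row-wise concatenation, and the exporter's `torch.chunk(..., 2, dim=0)` further down is its exact inverse. A small round-trip sketch with toy shapes:

```python
import torch

wi_0 = torch.randn(8, 4)  # gate projection (toy ffn_hidden=8, hidden=4)
wi_1 = torch.randn(8, 4)  # up projection

fused = torch.cat((wi_0, wi_1), dim=0)   # import direction: [16, 4]
gate, up = torch.chunk(fused, 2, dim=0)  # export direction
assert torch.equal(gate, wi_0) and torch.equal(up, wi_1)
```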


@io.state_transform(
source_key=("decoder.block.*.layer.2.DenseReluDense.wi_0.weight", "decoder.block.*.layer.2.DenseReluDense.wi_1.weight"),
source_key=(
"decoder.block.*.layer.2.DenseReluDense.wi_0.weight",
"decoder.block.*.layer.2.DenseReluDense.wi_1.weight",
),
target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_decoder_linear_fc1(down, gate):
@@ -602,8 +620,14 @@ def convert_state(self, source, target):
source,
target,
mapping=mapping,
transforms=[_export_encoder_qkv, _export_encoder_linear_fc1, _export_decoder_qkv, _export_decoder_kv, _export_decoder_linear_fc1],
state_dict_ignored_entries=['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
transforms=[
_export_encoder_qkv,
_export_encoder_linear_fc1,
_export_decoder_qkv,
_export_decoder_kv,
_export_decoder_linear_fc1,
],
state_dict_ignored_entries=['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'],
)

@property
@@ -628,6 +652,7 @@ def config(self) -> "HFT5Config":

def round_up_to_divisible(number, divisor):
import math

if divisor == 0:
raise ValueError("Divisor cannot be zero.")
return int(math.ceil(number / divisor) * divisor)
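
`round_up_to_divisible` pads the exported vocab size to a multiple of 128 so the Hugging Face config matches the divisibility padding applied on the Megatron side. A quick check of the arithmetic, using the function exactly as defined above:

```python
import math

def round_up_to_divisible(number, divisor):
    if divisor == 0:
        raise ValueError("Divisor cannot be zero.")
    return int(math.ceil(number / divisor) * divisor)

assert round_up_to_divisible(32100, 128) == 32128  # e.g. a 32,100-token vocab pads to 32,128
assert round_up_to_divisible(32128, 128) == 32128  # already divisible: unchanged
```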
@@ -643,7 +668,9 @@ def round_up_to_divisible(number, divisor):
relative_attention_max_distance=source.relative_attention_max_distance,
initializer_factor=source.init_method_std,
layer_norm_epsilon=source.layernorm_epsilon,
vocab_size=round_up_to_divisible(self.tokenizer.vocab_size + len(self.tokenizer.additional_special_tokens), 128),
vocab_size=round_up_to_divisible(
self.tokenizer.vocab_size + len(self.tokenizer.additional_special_tokens), 128
),
feed_forward_proj="gated-gelu",
tie_word_embeddings=source.share_embeddings_and_output_weights,
decoder_start_token_id=bos_id,
@@ -686,6 +713,7 @@ def _export_encoder_qkv(ctx: io.TransformCTX, linear_qkv):

return q_proj, k_proj, v_proj
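
The collapsed `_export_encoder_qkv` body reverses the import-time interleave, slicing the fused `linear_qkv` rows back into separate q/k/v projections. A minimal sketch of the inverse, continuing the toy dimensions from the import sketch above (the layout is assumed to be the per-head [q, k, v] interleave shown there):

```python
import torch

head_num, head_size, hidden_size = 8, 64, 512  # toy dimensions
linear_qkv = torch.randn(head_num * 3 * head_size, hidden_size)

# Undo the per-head interleave: rows are grouped as [q_h, k_h, v_h] per head h.
qkv = linear_qkv.view(head_num, 3, head_size, hidden_size)
q_proj = qkv[:, 0].reshape(head_num * head_size, hidden_size)
k_proj = qkv[:, 1].reshape(head_num * head_size, hidden_size)
v_proj = qkv[:, 2].reshape(head_num * head_size, hidden_size)
```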


@io.state_transform(
source_key="decoder.layers.*.self_attention.linear_qkv.weight",
target_key=(
@@ -720,6 +748,7 @@ def _export_decoder_qkv(ctx: io.TransformCTX, linear_qkv):

return q_proj, k_proj, v_proj


@io.state_transform(
source_key="decoder.layers.*.cross_attention.linear_kv.weight",
target_key=(
@@ -744,24 +773,33 @@ def _export_decoder_kv(ctx: io.TransformCTX, linear_kv):

return k_proj, v_proj


@io.state_transform(
source_key="encoder.layers.*.mlp.linear_fc1.weight",
target_key=("encoder.block.*.layer.1.DenseReluDense.wi_0.weight", "encoder.block.*.layer.1.DenseReluDense.wi_1.weight"),
target_key=(
"encoder.block.*.layer.1.DenseReluDense.wi_0.weight",
"encoder.block.*.layer.1.DenseReluDense.wi_1.weight",
),
)
def _export_encoder_linear_fc1(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj


@io.state_transform(
source_key="decoder.layers.*.mlp.linear_fc1.weight",
target_key=("decoder.block.*.layer.2.DenseReluDense.wi_0.weight", "decoder.block.*.layer.2.DenseReluDense.wi_1.weight"),
target_key=(
"decoder.block.*.layer.2.DenseReluDense.wi_0.weight",
"decoder.block.*.layer.2.DenseReluDense.wi_1.weight",
),
)
def _export_decoder_linear_fc1(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj


__all__ = [
"T5Model",
"T5Config",
2 changes: 1 addition & 1 deletion nemo/lightning/io/state.py
@@ -63,7 +63,7 @@ def apply_transforms(
are applied. Defaults to None.
state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases
where multiple entries in the model's state_dict point to one entry in the model's
named_parameters, i.e., several pointers share one parameter (`encoder.embed_tokens.weight`,
`decoder.embed_tokens.weight`, and `shared.weight` all point to `shared.weight`
in the Hugging Face T5 implementation). In these cases, the redundant entries are ignored.
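
A toy module makes the tied-weight situation above concrete; the module names are illustrative, not the actual Hugging Face T5 layout:

```python
import torch.nn as nn

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(10, 4)
        # Tie the encoder/decoder embeddings to the shared table, as HF T5 does.
        self.encoder_embed = self.shared
        self.decoder_embed = self.shared

m = TiedToy()
print(sorted(m.state_dict().keys()))
# ['decoder_embed.weight', 'encoder_embed.weight', 'shared.weight']
print([name for name, _ in m.named_parameters()])
# ['shared.weight'] -- one real parameter; the other state_dict keys are aliases
```

Passing the aliased keys via `state_dict_ignored_entries` lets `apply_transforms` skip them instead of treating them as unconverted weights.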
