Added Nanotron logging
TJ-Solergibert committed May 22, 2024
1 parent de81b53 commit a28c532
Showing 2 changed files with 61 additions and 55 deletions.
57 changes: 30 additions & 27 deletions tools/llama3/convert_hf_to_nanotron.py
@@ -8,21 +8,23 @@

import torch
import yaml
from nanotron.config import Config, GeneralArgs, ModelArgs, ParallelismArgs, TokenizerArgs
from nanotron import logging
from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs
from nanotron.config.models_config import ExistingCheckpointInit
from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.serialize import TrainingMetadata, save_meta, save_weights
from nanotron.serialize.metadata import DataStageMetadata
from nanotron.trainer import mark_tied_parameters
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.get_logger(__name__)

DEVICE = torch.device("cpu")
TORCH_DTYPE = torch.bfloat16

@@ -52,27 +54,23 @@ def get_args():

def main(args):
# Init Nanotron Parallel Utilities
parallel_config = ParallelismArgs(
dp=1,
pp=1,
tp=1,
pp_engine=AllForwardAllBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
assert (
parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE
and parallel_config.tp_linear_async_communication is False
)
parallel_config = ParallelismArgs(dp=1, pp=1, tp=1)

parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)

set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs())

# Load Llama3-8B HF model
print(f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}")
log_rank(
f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
hf_model = AutoModelForCausalLM.from_pretrained(
args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2"
).to(DEVICE)
@@ -103,7 +101,7 @@ def main(args):
)

# Init Llama3-8B Nanotron model
print("Init empty Nanotron Llama3 Model")
log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0)
nanotron_model = build_model(
model_builder=lambda: LlamaForTraining(
config=nanotron_llama_config,
@@ -120,9 +118,9 @@ def main(args):
sanity_check(root_module=nanotron_model)

# Copy params from HF to Nanotron
print("Copyng weights from HF model to Nanotron model...")
log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0)
# Token embeddings
print("Copyng Token Embeddings...")
log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0)
assert (
nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape
== hf_model.model.embed_tokens.weight.shape
@@ -135,7 +133,7 @@ def main(args):
# Decoder layers
for i in tqdm(
range(nanotron_llama_config.num_hidden_layers),
desc="Copyng Hidden Layers",
desc="Copying Hidden Layers",
total=nanotron_llama_config.num_hidden_layers,
):
# Input layer norm
@@ -207,24 +205,24 @@ def main(args):
)

# Last layer norm
print("Copyng Final Layer Norm...")
log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape
with torch.no_grad():
nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight)

# LM_Head
print("Copyng LM Head...")
log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape
with torch.no_grad():
nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight)

print("Copied weights from HF model to Nanotron model!")
log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0)
# Store weights
nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path)
save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path)

# Store metadata
print("Storing Nanotron model Configs and Metadata!")
log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0)
training_metadata = TrainingMetadata(
last_train_step=0,
consumed_train_samples=0,
@@ -248,14 +246,19 @@ def main(args):
),
tokenizer=TokenizerArgs(nanotron_checkpoint_path),
)
print("Saving config ...")
log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0)
yaml.dump(config.as_dict(), f)

with open(nanotron_checkpoint_path / "model_config.json", "w") as f:
print("Saving model config ...")
log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0)
json.dump(asdict(nanotron_llama_config), f)

print(f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}")
log_rank(
f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}",
logger=logger,
level=logging.INFO,
rank=0,
)


if __name__ == "__main__":
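For reference, the logging pattern this commit introduces is summarized in the sketch below. It is a minimal illustration built only from the calls visible in the diff (logging.get_logger, set_ranks_logging_level, log_rank); the message text is illustrative, and running it assumes the distributed environment a ParallelContext needs (e.g. a torchrun launch) is available.

    # Minimal sketch of the rank-aware logging pattern added in this commit.
    from nanotron import logging
    from nanotron.config import LoggingArgs, ParallelismArgs
    from nanotron.logging import log_rank, set_ranks_logging_level
    from nanotron.parallel import ParallelContext

    logger = logging.get_logger(__name__)

    # Single-process layout, as in the conversion script.
    parallel_config = ParallelismArgs(dp=1, pp=1, tp=1)
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )

    # Apply per-rank log levels from a default LoggingArgs.
    set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs())

    # Emits the message at INFO level on rank 0 only, instead of print() on every rank.
    log_rank("Loading pretrained Llama3 Model...", logger=logger, level=logging.INFO, rank=0)

The practical effect is that conversion progress is reported once per job rather than once per process, and the verbosity can be controlled through the standard LoggingArgs config.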
59 changes: 31 additions & 28 deletions tools/llama3/convert_nanotron_to_hf.py
@@ -7,19 +7,21 @@
from pathlib import Path

import torch
from nanotron.config import Config, ParallelismArgs, get_config_from_file
from nanotron import logging
from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.serialize import load_weights
from nanotron.trainer import mark_tied_parameters
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import LlamaConfig as LlamaConfigHF

logger = logging.get_logger(__name__)

DEVICE = torch.device("cpu")
TORCH_DTYPE = torch.bfloat16

@@ -41,7 +43,6 @@ def get_args():
required=True,
help="A path to a directory to store the converted checkpoint",
)
# TODO Add push to hub

args = parser.parse_args()

@@ -50,34 +51,31 @@

def main(args):
# Init Nanotron Parallel Utilities
parallel_config = ParallelismArgs(
dp=1,
pp=1,
tp=1,
pp_engine=AllForwardAllBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
assert (
parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE
and parallel_config.tp_linear_async_communication is False
)
parallel_config = ParallelismArgs(dp=1, pp=1, tp=1)

parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)

set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs())

# Load Nanotron checkpoint config
print(f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}")
log_rank(
f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}",
logger=logger,
level=logging.INFO,
rank=0,
)
nanotron_config = get_config_from_file(
os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None
)
nanotron_llama_config = nanotron_config.model.model_config

# Init Llama3-8B Nanotron model
print("Init empty Nanotron Llama3 Model")
log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0)

nanotron_model = build_model(
model_builder=lambda: LlamaForTraining(
config=nanotron_config.model.model_config,
@@ -94,23 +92,23 @@ def main(args):
sanity_check(root_module=nanotron_model)

# Load Nanotron Checkpoint
print("Loading Nanotron Llama3 Model...")
log_rank("Loading Nanotron Llama3 Model...", logger=logger, level=logging.INFO, rank=0)
load_weights(
model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)
)

# Build empty HF Model
print("Init empty HF Llama3 Model")
log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0)
hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time
config=LlamaConfigHF(**asdict(nanotron_llama_config)),
torch_dtype=TORCH_DTYPE,
attn_implementation="flash_attention_2",
).to(DEVICE)

# Copy params from Nanotron to HF
print("Copyng weights from Nanotron model to HF model...")
log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0)
# Token embeddings
print("Copyng Token Embeddings...")
log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0)
assert (
nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape
== hf_model.model.embed_tokens.weight.shape
@@ -123,7 +121,7 @@ def main(args):
# Decoder layers
for i in tqdm(
range(nanotron_llama_config.num_hidden_layers),
desc="Copyng Hidden Layers",
desc="Copying Hidden Layers",
total=nanotron_llama_config.num_hidden_layers,
):
# Input layer norm
@@ -199,26 +197,31 @@ def main(args):
)

# Last layer norm
print("Copyng Final Layer Norm...")
log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape
with torch.no_grad():
hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight)

# LM_Head
print("Copyng LM Head...")
log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0)
assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape
with torch.no_grad():
hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight)

print("Copied weights from Nanotron model to HF model!")
log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0)
# Store weights
print("Storing HF model Checkpoint and Tokenizer!")
log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0)
hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True)
# Store tokenizer
tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path)
tokenizer.save_pretrained(args.hugging_face_checkpoint_path)

print(f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}")
log_rank(
f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}",
logger=logger,
level=logging.INFO,
rank=0,
)


if __name__ == "__main__":
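Aside from the print-to-log_rank substitution, both converters rely on one recurring pattern around each logged step: assert that source and destination tensors have identical shapes, then copy in place under torch.no_grad(). The helper below is a hypothetical distillation of that pattern (copy_param is not a function in these scripts), shown only to make the convention explicit.

    import torch

    def copy_param(dst: torch.Tensor, src: torch.Tensor) -> None:
        # Conversion should fail loudly on any shape mismatch.
        assert dst.shape == src.shape
        # In-place copy without autograd tracking, mirroring how the scripts
        # transfer embeddings, decoder layers, the final norm and the LM head.
        with torch.no_grad():
            dst.copy_(src)

    # e.g. copy_param(hf_model.lm_head.weight, nanotron_model.model.lm_head.pp_block.weight)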
