Checkpoint 1.3 backwards compatibility #152

Open · wants to merge 10 commits into main
24 changes: 1 addition & 23 deletions scripts/fix_checkpoint_bad_naming.py
@@ -13,31 +13,9 @@
"""

import argparse
import os
import re
from pathlib import Path


def update_checkpoint(checkpoint_dir: str):
print(f"Updating checkpoint in {checkpoint_dir}")
for root, _, files in os.walk(checkpoint_dir):
for file in files:
if file.endswith(".safetensors"):
# r'(?<=model)_(model)' means match the string '_model' that is preceded by 'model'
if len(re.findall(r"(?<=model)_(model)", file)) == 0:
continue
# we remove second _model
new_file = re.sub(r"(?<=model)_(model)", "", file)
# we would have "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# let's assert we have two matches of ".safetensors"
assert len(re.findall(r".safetensors", new_file)) == 2
# then we remove first match
new_file = re.sub(r".safetensors", "", new_file, count=1)
# so that we get "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

print(f"Renaming {file} to {new_file}")
os.rename(os.path.join(root, file), os.path.join(root, new_file))
from nanotron.serialize.legacy import update_checkpoints_with_wrong_prefix as update_checkpoint


def main():
25 changes: 25 additions & 0 deletions src/nanotron/serialize/legacy.py
@@ -0,0 +1,25 @@
import os
import re
from pathlib import Path


def update_checkpoints_with_wrong_prefix(checkpoint_dir: Path):
print(f"Updating checkpoint in {checkpoint_dir}")
for root, _, files in os.walk(checkpoint_dir):
for file in files:
if file.endswith(".safetensors"):
# r'(?<=model)_(model)' means match the string '_model' that is preceded by 'model'
if len(re.findall(r"(?<=model)_(model)", file)) == 0:
continue
# we remove second _model
new_file = re.sub(r"(?<=model)_(model)", "", file)
# we would have "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# let's assert we have two matches of ".safetensors"
assert len(re.findall(r".safetensors", new_file)) == 2
# then we remove first match
new_file = re.sub(r".safetensors", "", new_file, count=1)
# so that we get "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

print(f"Renaming {file} to {new_file}")
os.rename(os.path.join(root, file), os.path.join(root, new_file))
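
Note (not part of this diff): a minimal sketch of the rename performed by update_checkpoints_with_wrong_prefix, using a hypothetical shard name consistent with the comments above.

import re

# Hypothetical old-style shard name: duplicated "model_" prefix and a ".safetensors" in the middle.
bad = "model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# Step 1: drop the "_model" that immediately follows "model".
step1 = re.sub(r"(?<=model)_(model)", "", bad)
assert step1 == "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"

# Step 2: drop the first of the two ".safetensors" occurrences.
fixed = re.sub(r".safetensors", "", step1, count=1)
assert fixed == "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
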
23 changes: 16 additions & 7 deletions src/nanotron/serialize/utils.py
@@ -30,31 +30,40 @@ def get_exp_tp_pp_rank_and_size_from(
def get_path(
tensor_name: str,
type: ObjectType,
exp_tp_pp_rank_and_size: Tuple[Tuple[int, int], Tuple[int, int]],
exp_tp_pp_rank_and_size: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
is_expert_sharded: bool,
prefix: Optional[Path] = None,
) -> List[str]:
return_all_matches: bool = False,
) -> Path | List[Path]:
suffix = tensor_name.split(".")
suffix_path, suffix_name = suffix[:-1], suffix[-1]

if exp_tp_pp_rank_and_size:
# We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided
(exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size
if not is_expert_sharded or exp_size == 1:
pattern = f"{type.value}_{suffix_name}*.safetensors"
suffix_name = (
f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}.safetensors"
)
else:
# We only show exp_rank if tensor is exp_sharded and exp_size > 1
pattern = f"{type.value}_{suffix_name}*exp-rank-{exp_rank}-of-{exp_size}.safetensors"
suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors"
else:
pattern = f"{type.value}_{suffix_name}*.safetensors"
suffix_name = f"{type.value}_{suffix_name}.safetensors"

suffix_path.append(suffix_name)
if prefix is None:
return suffix_path
if return_all_matches:
if prefix is None:
return list(Path(suffix_path[0]).joinpath(*suffix_path[1:]).glob(pattern))
else:
return list(prefix.joinpath(*suffix_path).glob(pattern))
else:
return prefix.joinpath(*suffix_path)
suffix_path.append(suffix_name)
if prefix is None:
return Path(suffix_path[0]).joinpath(*suffix_path[1:])
else:
return prefix.joinpath(*suffix_path)


def extract_tp_pp_rank_from_shard_path(shard_path: Path):
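
Note (not part of this diff): a minimal sketch of how the glob pattern built by get_path with return_all_matches=True discovers every shard of a tensor, whatever pp/tp topology it was saved with. The directory layout and file names are hypothetical, following the naming scheme above.

import tempfile
from pathlib import Path

shard_dir = Path(tempfile.mkdtemp()) / "model" / "token_embedding"
shard_dir.mkdir(parents=True)

# Shards written with pp=1, tp=2 (hypothetical names following the scheme above).
for tp_rank in range(2):
    (shard_dir / f"model_weight_pp-rank-0-of-1_tp-rank-{tp_rank}-of-2.safetensors").touch()

# Non-expert-sharded pattern from get_path: f"{type.value}_{suffix_name}*.safetensors".
shards = sorted(shard_dir.glob("model_weight*.safetensors"))
assert len(shards) == 2  # both tp shards are found, even if the current run uses a different topology
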
39 changes: 27 additions & 12 deletions src/nanotron/serialize/weights.py
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

@@ -15,6 +16,7 @@
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, ShardedInfo, SlicesPair
from nanotron.serialize.legacy import update_checkpoints_with_wrong_prefix
from nanotron.serialize.metadata import CheckpointMetadata, TensorMetadata, load_meta
from nanotron.serialize.utils import (
ObjectType,
@@ -209,6 +211,17 @@ def load_weights(

checkpoint_version: Optional[Version] = None

# Detect shards that still use the old duplicated ".safetensors" naming and fix them in place.
wrong_names = {
Path(root) / name
for root, _, names in os.walk(param_root_folder)
for name in names
if name.count(".safetensors") == 2
}
if len(wrong_names) > 0:
print(f"Note: Old checkpoints detected. Upgrading to v{CHECKPOINT_VERSION}")
update_checkpoints_with_wrong_prefix(param_root_folder)

filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict()
param_shard_metadata = {}
for name, param_or_buffer in tqdm(
@@ -259,8 +272,8 @@ def load_weights(
)

if path.exists():
# If the exact path exists, then the topology did not change.
with safe_open(path, framework="pt", device=str(param.device)) as fi:
# TODO @thomasw21: Choose only a slice if we switch the TP topology
param_or_buffer[:] = fi.get_tensor("data")

elif not path.parent.exists():
@@ -276,20 +289,20 @@
raise ValueError(
f"`{name}` is not a sharded parameter. It's possible you were expecting {path} to exist."
)
# TODO @thomasw21: Make so that we don't need to code this logic somewhere else than in `get_path`
sharded_info = param.get_sharded_info()
suffix = base_name.rsplit(".", 1)[-1]
shards_path = list(path.parent.glob(f"{ObjectType.MODEL.value}_{suffix}*.safetensors"))
shards_path = get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
prefix=param_root_folder,
is_expert_sharded=is_expert_sharded,
return_all_matches=True,
)

if len(shards_path) <= 0:
raise ValueError(
f"Could not find any shards {ObjectType.MODEL.value}_{suffix}*.safetensors in {path.parent}."
f"If you notice `.safetensors` in the middle of the name of some of the checkpoints files. You need to run `scripts/fix_checkpoint_bad_naming.py`."
)
raise ValueError(f"Could not find any shards in {path.parent}.")

if checkpoint_version is None:
checkpoint_version = get_checkpoint_version(
parallel_context, root_folder, param_save_path=shards_path[0]
)
checkpoint_version = get_checkpoint_version(parallel_context, root_folder, shards_path[0])
else:
current_checkpoint_version = None
try:
@@ -304,6 +317,8 @@
current_checkpoint_version == checkpoint_version
), f"Checkpoint version mismatch at {shards_path[0]}."

# Load the param.
sharded_info = param.get_sharded_info()
if checkpoint_version <= CHECKPOINT_VERSION:
load_sharded_param_latest(
param_or_buffer=param_or_buffer,
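
Note (not part of this diff): the detection rule added in load_weights flags a shard as old-style when ".safetensors" appears twice in its file name. A minimal sketch with hypothetical names:

names = [
    "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors",  # current naming
    "model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors",  # old naming
]
old_style = [n for n in names if n.count(".safetensors") == 2]
assert old_style == [names[1]]
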
82 changes: 82 additions & 0 deletions tests/helpers/dummy.py
@@ -13,6 +13,8 @@
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.parallel.utils import initial_sync
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear, TensorParallelRowLinear
from torch import nn
from torch.nn.parallel import DistributedDataParallel

@@ -64,6 +66,86 @@ def forward(self, x: Union[torch.Tensor, TensorPointer]):
return x


class DummyParallelModel(nn.Module):
def __init__(self, p2p: P2P, tp_pg: dist.ProcessGroup, num_layers: int = 8, hidden_size: int = 16):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.p2p = p2p
self.mlp = nn.Sequential(
*(
nn.ModuleDict(
{
"linear1": PipelineBlock(
p2p=p2p,
module_builder=TensorParallelColumnLinear,
module_kwargs={"in_features": hidden_size, "out_features": hidden_size,
"pg": tp_pg, "mode": TensorParallelLinearMode.ALL_REDUCE,
"async_communication": True},
module_input_keys={"x"},
module_output_keys={"output"},
),
"activation": PipelineBlock(
p2p=p2p,
module_builder=nn.Sigmoid,
module_kwargs={},
module_input_keys={"input"},
module_output_keys={"output"},
),
"linear2": PipelineBlock(
p2p=p2p,
module_builder=TensorParallelRowLinear,
module_kwargs={"in_features": hidden_size, "out_features": hidden_size,
"pg": tp_pg, "mode": TensorParallelLinearMode.ALL_REDUCE},
module_input_keys={"x"},
module_output_keys={"output"},
),

}
)
for pp_rank in range(num_layers)
)
)

self.loss = PipelineBlock(
p2p=p2p,
module_builder=lambda: lambda x: x.sum(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)

def forward(self, x: torch.Tensor | TensorPointer, return_loss: bool = True):
for non_linear in self.mlp:
x = non_linear.linear1(x=x)["output"]
x = non_linear.activation(input=x)["output"]
x = non_linear.linear2(x=x)["output"]
if return_loss:
x = self.loss(x=x)["output"]
return x


def init_dummy_parallel_model(parallel_context: ParallelContext, dtype: torch.dtype = torch.float,
num_layers: int = 8, hidden_size: int = 16) -> DummyParallelModel:
p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda"))
model = DummyParallelModel(p2p=p2p, tp_pg=parallel_context.tp_pg, num_layers=num_layers, hidden_size=hidden_size)

# Build model.
pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)]
with init_on_device_and_dtype(device=torch.device("cuda"), dtype=dtype):
contiguous_size = ceil(len(pipeline_blocks) / parallel_context.pp_pg.size())
for i, block in enumerate(pipeline_blocks):
rank = i // contiguous_size
block.build_and_set_rank(rank)

initial_sync(model=model, parallel_context=parallel_context)

assert len(list(model.named_parameters())) > 0
model = DistributedDataParallel(model, process_group=parallel_context.dp_pg)

return model


def init_dummy_model(parallel_context: ParallelContext, dtype: torch.dtype = torch.float) -> DummyModel:
p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda"))
model = DummyModel(p2p=p2p)
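
Note (not part of this diff): a minimal sketch of the contiguous block-to-pipeline-rank assignment used in init_dummy_parallel_model above, as plain arithmetic.

from math import ceil

def assign_pipeline_ranks(num_blocks: int, pp_size: int) -> list:
    # Consecutive PipelineBlocks are grouped onto the same pipeline rank.
    contiguous_size = ceil(num_blocks / pp_size)
    return [i // contiguous_size for i in range(num_blocks)]

# DummyParallelModel with 8 layers has 8 * 3 blocks (linear1, activation, linear2) plus the loss block.
print(assign_pipeline_ranks(25, 4))  # blocks 0-6 -> rank 0, 7-13 -> rank 1, 14-20 -> rank 2, 21-24 -> rank 3
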
92 changes: 88 additions & 4 deletions tests/test_serialize.py
@@ -1,7 +1,9 @@
import os
from pathlib import Path
import pytest
import torch
from helpers.context import TestContext
from helpers.dummy import dummy_infinite_data_loader, init_dummy_model
from helpers.dummy import dummy_infinite_data_loader, init_dummy_model, init_dummy_parallel_model
from helpers.utils import (
available_gpus,
get_all_3d_configurations,
@@ -22,6 +24,7 @@
AllForwardAllBackwardPipelineEngine,
)
from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tied_parameters import sync_tied_weights_gradients
from nanotron.random import RandomStates, get_current_random_state, get_synced_random_state
from nanotron.serialize import (
@@ -36,9 +39,90 @@
from torch.nn.parallel import DistributedDataParallel


def test_save_and_load_with_changed_topolgy():
# TODO @thomasw21: We want to be able to support a change of topology mechanism
return
def _create_model_and_serialize_all(parallel_context: ParallelContext, root_path: Path):
# Create and save model.
model = init_dummy_parallel_model(parallel_context)
model = model.requires_grad_(False) # Otherwise we need to specify pipeline engine.

# Get inputs and outputs.
current_pp_rank = dist.get_rank(parallel_context.pp_pg)
if current_pp_rank == 0:
inputs = torch.randn(2, 16).cuda()
else:
inputs = TensorPointer(group_rank=0)
outputs = model(inputs, return_loss=False)

# Serialize everything.
save_weights(model=model, parallel_context=parallel_context, root_folder=root_path/"model")
rank = torch.distributed.get_rank()
world_size = int(os.environ["WORLD_SIZE"])
if rank == 0: # Save inputs on first process.
torch.save(inputs.detach().cpu(), root_path/"inputs.pt")
if rank == world_size - 1: # Save outputs on last process.
torch.save(outputs.detach().cpu(), root_path/"outputs.pt")


def _load_from_serialized(parallel_context: ParallelContext, load_path: Path, save_path: Path):
# Load inputs.
world_size = int(os.environ["WORLD_SIZE"])
current_pp_rank = dist.get_rank(parallel_context.pp_pg)
if current_pp_rank == 0:
inputs = torch.load(load_path/"inputs.pt").cuda()
else:
inputs = TensorPointer(group_rank=0)

# Create new random model.
model = init_dummy_parallel_model(parallel_context=parallel_context)
model = model.requires_grad_(False)
outputs = model(inputs, return_loss=False)
rank = torch.distributed.get_rank()
if rank == world_size - 1:
expected_outputs = torch.load(load_path/"outputs.pt")
assert not torch.allclose(outputs.detach().cpu(), expected_outputs)

# Load model and save correct outputs and checkpoint.
load_weights(model=model, parallel_context=parallel_context, root_folder=load_path/"model")
outputs = model(inputs, return_loss=False)
save_weights(model=model, parallel_context=parallel_context, root_folder=save_path/"model")
if rank == world_size - 1:
torch.save(outputs.detach().cpu(), save_path/"outputs.pt")


@pytest.mark.parametrize(
"tp1,dp1,pp1,tp2,dp2,pp2",
[
pytest.param(tp1, dp1, pp1, tp2, dp2, pp2)
for gpus in range(1, min(available_gpus(), 4) + 1)
for tp1, dp1, pp1 in get_all_3d_configurations(gpus)
for tp2, dp2, pp2 in get_all_3d_configurations(gpus)
if (tp1, dp1, pp1) != (tp2, dp2, pp2) and # ensure topology changed
16 % tp1 == 0 and 16 % tp2 == 0 and # hidden size 16 divides evenly across TP ranks
8 % pp1 == 0 and 8 % pp2 == 0 # the 8 layers divide evenly across PP ranks
]
)
def test_save_and_load_with_changed_topolgy(tp1: int, dp1: int, pp1: int, tp2: int, dp2: int, pp2: int):
# Set up test.
print("Testing", tp1, dp1, pp1, tp2, dp2, pp2)
test_context = TestContext()
root = test_context.get_auto_remove_tmp_dir()
model1_path = root/"model1"
model2_path = root/"model2"
model1_path.mkdir()
model2_path.mkdir()

# Create first model.
init_distributed(tp=tp1, dp=dp1, pp=pp1)(_create_model_and_serialize_all)(root_path=model1_path)
assert (model1_path/"model").exists()
assert (model1_path/"inputs.pt").exists()
assert (model1_path/"outputs.pt").exists()

# Create second model and compare outputs.
init_distributed(tp=tp2, dp=dp2, pp=pp2)(_load_from_serialized)(load_path=model1_path, save_path=model2_path)
assert (model2_path/"outputs.pt").exists()

outputs1 = torch.load(model1_path/"outputs.pt")
outputs2 = torch.load(model2_path/"outputs.pt")
assert torch.allclose(outputs1, outputs2)
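
Note (not part of this diff): a sketch of what the parametrization above expands to, assuming get_all_3d_configurations(gpus) yields every (tp, dp, pp) triple whose product equals gpus — an assumption about the helper, not taken from its source.

def all_3d_configurations(gpus: int):
    # Assumed behaviour: enumerate every factorization of `gpus` into tp * dp * pp.
    for tp in range(1, gpus + 1):
        if gpus % tp:
            continue
        for dp in range(1, gpus // tp + 1):
            if (gpus // tp) % dp:
                continue
            yield tp, dp, (gpus // tp) // dp

# With 2 GPUs: (1, 1, 2), (1, 2, 1), (2, 1, 1). The test then keeps pairs of distinct
# topologies where 16 % tp == 0 and 8 % pp == 0.
print(list(all_3d_configurations(2)))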


@pytest.mark.parametrize(