-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
62 lines (51 loc) · 2.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from examples.mlp_example.config import MLPConfig
from examples.mlp_example.context import MLPContext
from examples.mlp_example.data import MNISTDataset
from examples.mlp_example.model import init_model, init_optimizer, loss_function, metrics_aggregation_fn
from scaling.core import BaseTrainer
from scaling.core.logging import logger
from scaling.core.runner import LaunchConfig
from scaling.core.topology import Topology
def main(launch_config: LaunchConfig) -> None:
    """Entry point for distributed MLP training on MNIST.

    Rebuilds the full ``MLPConfig`` from the launcher-provided payload,
    injects the per-process topology coordinates (world size, global rank,
    local slot), initializes the distributed context, then constructs and
    runs a ``BaseTrainer``.

    Args:
        launch_config: Runner-supplied launch parameters (payload dict,
            rank/world-size coordinates, master address/port).

    Raises:
        ValueError: If the launch payload or its ``"topology"`` section is
            missing.
    """
    config_payload = launch_config.payload
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let a None payload fail confusingly later.
    if config_payload is None:
        raise ValueError("launch config carries no payload; cannot build MLPConfig")
    topology_payload = config_payload["topology"]
    if topology_payload is None:
        raise ValueError("launch config payload is missing its 'topology' section")

    # Inject the coordinates that only the launcher knows at runtime.
    topology_payload["world_size"] = launch_config.world_size
    topology_payload["global_rank"] = launch_config.global_rank
    topology_payload["local_slot"] = launch_config.local_slot

    config = MLPConfig.from_dict(config_payload)
    topology = Topology(config=config.topology)
    context = MLPContext(config=config, topology=topology)

    # Configure logging before distributed init so every subsequent log
    # line is tagged with this process's global rank.
    logger.configure(
        config=config.logger,
        name=f"RANK {topology.config.global_rank}",
        global_rank=topology.config.global_rank,
    )

    # Sets up the process group (master address/port) and seeds RNGs.
    context.initialize(
        master_addr=launch_config.master_addr,
        master_port=str(launch_config.master_port),
        seed=config.trainer.seed,
    )

    model = init_model(context=context)
    optimizer = init_optimizer(context=context, model=model)

    # Only I/O ranks materialize the datasets; other ranks keep None and
    # presumably receive batches via sync_batch_to_model_parallel — confirm
    # against BaseTrainer's data-flow contract.
    train_data = None
    valid_data = None
    if topology.is_io_rank:
        train_data = MNISTDataset(train=True)
        valid_data = MNISTDataset(train=False)

    trainer = BaseTrainer(
        config=context.config.trainer,
        context=context,
        parallel_module=model,
        optimizer=optimizer,
        dataset=train_data,
        dataset_evaluation=valid_data,
        sync_batch_to_model_parallel=MNISTDataset.sync_batch_to_model_parallel,
        metrics_aggregation_fn=metrics_aggregation_fn,
        loss_function=loss_function,  # type: ignore[arg-type]
    )
    trainer.run_training()
if __name__ == "__main__":
    # Script entry: parse launcher-provided arguments and start training.
    main(LaunchConfig.from_launcher_args())