-
Notifications
You must be signed in to change notification settings - Fork 2
/
config.py
57 lines (43 loc) · 1.62 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from pydantic import Field
from scaling.core import (
BaseConfig,
LearningRateSchedulerConfig,
OptimizerConfig,
ProfilerConfig,
RunnerConfig,
TopologyConfig,
TrainerConfig,
)
from scaling.core.logging import LoggerConfig
class MLPArchitectureConfig(BaseConfig):
    """Shape of the multi-layer perceptron.

    Both fields carry pydantic ``Field`` constraints: ``n_hidden_layers``
    must be >= 0 and ``hidden_dim`` must be strictly positive.
    """

    # Depth counts hidden layers only; input and output layers are excluded.
    # The default of 0 means no hidden layers at all.
    n_hidden_layers: int = Field(
        default=0,
        ge=0,
        description="The number of layers in the network, excluding input and output layers.",
    )
    # Width shared by every hidden layer.
    hidden_dim: int = Field(
        default=64,
        gt=0,
        description="The number of hidden units in each hidden layer.",
    )
class TrainingConfig(BaseConfig):
    """Training hyperparameters.

    NOTE(review): only ``weight_decay`` is declared here; where and how it is
    consumed (e.g. handed to the optimizer) is defined elsewhere — confirm at
    the call site.
    """

    # Defaults to 0.0001; no range constraint (ge/gt) is enforced.
    weight_decay: float = Field(default=0.0001, description="")
class MLPConfig(BaseConfig):
    """Top-level configuration tree for the MLP setup.

    Each field holds one sub-config. Every default below is an instance that
    is constructed once, when this class body executes (i.e. at import time),
    in the order the fields are declared.
    """

    runner: RunnerConfig = Field(default=RunnerConfig(), description="")
    logger: LoggerConfig = Field(default=LoggerConfig(), description="")
    # Default layout: no model, pipeline or data parallelism (all sizes 1),
    # micro-batches of 2 samples, and a single gradient-accumulation step.
    # The type-ignore stays on the call line: the constructor is invoked
    # without every argument its signature declares.
    topology: TopologyConfig = Field(
        default=TopologyConfig(  # type: ignore[call-arg]
            model_parallel_size=1,
            pipe_parallel_size=1,
            data_parallel_size=1,
            micro_batch_size=2,
            gradient_accumulation_steps=1,
        ),
        description="",
    )
    optimizer: OptimizerConfig = Field(default=OptimizerConfig(), description="")
    learning_rate_scheduler: LearningRateSchedulerConfig = Field(
        default=LearningRateSchedulerConfig(), description=""
    )
    training: TrainingConfig = Field(default=TrainingConfig(), description="")
    trainer: TrainerConfig = Field(default=TrainerConfig(), description="")
    profiler: ProfilerConfig = Field(default=ProfilerConfig(), description="")
    architecture: MLPArchitectureConfig = Field(
        default=MLPArchitectureConfig(), description=""
    )