-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: ✏️ update configs of iTransformer
- Loading branch information
Showing
2 changed files
with
160 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import os | ||
import sys | ||
from easydict import EasyDict | ||
sys.path.append(os.path.abspath(__file__ + '/../../..')) | ||
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse | ||
from basicts.data import TimeSeriesForecastingDataset | ||
from basicts.runners import SimpleTimeSeriesForecastingRunner | ||
from basicts.scaler import ZScoreScaler | ||
from basicts.utils import get_regular_settings | ||
|
||
from .arch import iTransformer | ||
|
||
############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'Traffic'  # Dataset name
# Per-dataset defaults looked up by name; the keys read below are the only
# ones this config relies on.
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = 96  # Length of the input sequence
OUTPUT_LEN = 720  # Length of the output (prediction) sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data
# Model architecture and parameters
MODEL_ARCH = iTransformer
NUM_NODES = 862
# Keyword arguments for MODEL_ARCH.
# NOTE(review): several keys (e.g. "moving_avg", "p_hidden_dims", "sigma",
# "distil") look inherited from other model configs -- confirm which ones the
# iTransformer arch actually reads.
MODEL_PARAM = {
    "enc_in": NUM_NODES,  # num nodes
    "dec_in": NUM_NODES,
    "c_out": NUM_NODES,
    "seq_len": INPUT_LEN,
    # Start token length used in decoder. Floor division keeps this an int;
    # `INPUT_LEN / 2` would yield the float 48.0 for a length parameter.
    "label_len": INPUT_LEN // 2,
    "pred_len": OUTPUT_LEN,  # prediction sequence length
    "factor": 3,  # attn factor
    "p_hidden_dims": [128, 128],
    "p_hidden_layers": 2,
    "d_model": 512,
    "moving_avg": 25,  # window size of moving average. This is a CRUCIAL hyper-parameter.
    "n_heads": 8,
    "e_layers": 4,  # num of encoder layers
    "d_layers": 1,  # num of decoder layers
    "d_ff": 512,
    "distil": True,
    "sigma": 0.2,
    "dropout": 0.1,
    "freq": 'h',
    "use_norm": True,
    "output_attention": False,
    "embed": "timeF",  # [timeF, fixed, learned]
    "activation": "gelu",
    "num_time_features": 4,  # number of used time features
    "time_of_day_size": 24,
    "day_of_week_size": 7,
    "day_of_month_size": 31,
    "day_of_year_size": 366,
}
NUM_EPOCHS = 20
|
||
############################## General Configuration ##############################
# Root config object; the sections that follow attach their settings to it.
CFG = EasyDict(
    DESCRIPTION='An Example Config',
    GPU_NUM=1,  # Number of GPUs to use (0 for CPU mode)
    RUNNER=SimpleTimeSeriesForecastingRunner,  # Runner class
)
|
||
############################## Dataset Configuration ##############################
# Dataset name, dataset class, and the arguments stored for that class.
CFG.DATASET = EasyDict(
    NAME=DATA_NAME,
    TYPE=TimeSeriesForecastingDataset,
    PARAM={
        'dataset_name': DATA_NAME,
        'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
        'input_len': INPUT_LEN,
        'output_len': OUTPUT_LEN,
        # 'mode' is automatically set by the runner
    },
)
|
||
############################## Scaler Configuration ##############################
# Scaler class plus its arguments; only the train share of the split ratio
# (first element) is passed as 'train_ratio'.
CFG.SCALER = EasyDict(
    TYPE=ZScoreScaler,  # Scaler class
    PARAM={
        'dataset_name': DATA_NAME,
        'train_ratio': TRAIN_VAL_TEST_RATIO[0],
        'norm_each_channel': NORM_EACH_CHANNEL,
        'rescale': RESCALE,
    },
)
|
||
############################## Model Configuration ##############################
# Model class, its display name, and its constructor arguments.
CFG.MODEL = EasyDict(
    NAME=MODEL_ARCH.__name__,
    ARCH=MODEL_ARCH,
    PARAM=MODEL_PARAM,
    # Feature indices; presumably channels of the input fed to the model and
    # of the target it is scored on -- confirm against the runner.
    FORWARD_FEATURES=[0, 1, 2, 3, 4],
    TARGET_FEATURES=[0],
)
|
||
############################## Metrics Configuration ##############################
# Metric functions by name, the target metric name, and the null value used
# for the data.
CFG.METRICS = EasyDict(
    FUNCS={
        'MAE': masked_mae,
        'MSE': masked_mse,
        'RMSE': masked_rmse,
        'MAPE': masked_mape,
    },
    TARGET='MSE',  # must name one of the FUNCS keys -- presumably drives model selection; confirm
    NULL_VAL=NULL_VAL,
)
|
||
############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
# Checkpoints go to checkpoints/<model name>/<data>_<epochs>_<input len>_<output len>
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    f'{DATA_NAME}_{CFG.TRAIN.NUM_EPOCHS}_{INPUT_LEN}_{OUTPUT_LEN}',
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict(
    TYPE="Adam",
    PARAM={"lr": 0.001},
)
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict(
    TYPE="MultiStepLR",
    PARAM={"milestones": [5, 10], "gamma": 0.5},
)
# Gradient clipping arguments
CFG.TRAIN.CLIP_GRAD_PARAM = {'max_norm': 5.0}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict(BATCH_SIZE=32, SHUFFLE=True)
|
||
############################## Validation Configuration ##############################
# INTERVAL is presumably measured in epochs -- confirm against the runner.
CFG.VAL = EasyDict(
    INTERVAL=1,
    DATA=EasyDict(BATCH_SIZE=32),  # Validation data loader settings
)
|
||
############################## Test Configuration ##############################
# INTERVAL is presumably measured in epochs -- confirm against the runner.
CFG.TEST = EasyDict(
    INTERVAL=1,
    DATA=EasyDict(BATCH_SIZE=32),  # Test data loader settings
)
|
||
############################## Evaluation Configuration ##############################
# Evaluation parameters
CFG.EVAL = EasyDict(
    USE_GPU=False,  # Whether to use GPU for evaluation. Default: True
)