-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: 🎸 add a LTSF baseline SOFTS (added by @superarthurlx)
- Loading branch information
Showing
6 changed files
with
307 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import os | ||
import sys | ||
from easydict import EasyDict | ||
sys.path.append(os.path.abspath(__file__ + '/../../..')) | ||
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse | ||
from basicts.data import TimeSeriesForecastingDataset | ||
from basicts.runners import SimpleTimeSeriesForecastingRunner | ||
from basicts.scaler import ZScoreScaler | ||
from basicts.utils import get_regular_settings | ||
|
||
from .arch import SOFTS | ||
|
||
############################## Hot Parameters ############################## | ||
# Dataset & Metrics configuration | ||
DATA_NAME = 'ETTh1' # Dataset name | ||
regular_settings = get_regular_settings(DATA_NAME) | ||
INPUT_LEN = regular_settings['INPUT_LEN'] # 336, better performance | ||
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence | ||
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios | ||
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data | ||
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data | ||
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data | ||
# Model architecture and parameters | ||
MODEL_ARCH = SOFTS | ||
NUM_NODES = 7 | ||
MODEL_PARAM = { | ||
"enc_in": NUM_NODES, # num nodes | ||
"dec_in": NUM_NODES, | ||
"c_out": NUM_NODES, | ||
"seq_len": INPUT_LEN, | ||
"pred_len": OUTPUT_LEN, # prediction sequence length | ||
"e_layers": 2, # num of encoder layers | ||
"d_model": 256, | ||
"d_core": 256, | ||
"d_ff": 512, | ||
"dropout": 0.0, | ||
"use_norm" : True, | ||
"activation": "gelu", | ||
"num_time_features": 4, # number of used time features | ||
"time_of_day_size": 24, | ||
"day_of_week_size": 7, | ||
"day_of_month_size": 31, | ||
"day_of_year_size": 366 | ||
} | ||
NUM_EPOCHS = 50 | ||
|
||
############################## General Configuration ############################## | ||
CFG = EasyDict() | ||
# General settings | ||
CFG.DESCRIPTION = 'An Example Config' | ||
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) | ||
# Runner | ||
CFG.RUNNER = SimpleTimeSeriesForecastingRunner | ||
|
||
############################## Dataset Configuration ############################## | ||
CFG.DATASET = EasyDict() | ||
# Dataset settings | ||
CFG.DATASET.NAME = DATA_NAME | ||
CFG.DATASET.TYPE = TimeSeriesForecastingDataset | ||
CFG.DATASET.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, | ||
'input_len': INPUT_LEN, | ||
'output_len': OUTPUT_LEN, | ||
# 'mode' is automatically set by the runner | ||
}) | ||
|
||
############################## Scaler Configuration ############################## | ||
CFG.SCALER = EasyDict() | ||
# Scaler settings | ||
CFG.SCALER.TYPE = ZScoreScaler # Scaler class | ||
CFG.SCALER.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_ratio': TRAIN_VAL_TEST_RATIO[0], | ||
'norm_each_channel': NORM_EACH_CHANNEL, | ||
'rescale': RESCALE, | ||
}) | ||
|
||
############################## Model Configuration ############################## | ||
CFG.MODEL = EasyDict() | ||
# Model settings | ||
CFG.MODEL.NAME = MODEL_ARCH.__name__ | ||
CFG.MODEL.ARCH = MODEL_ARCH | ||
CFG.MODEL.PARAM = MODEL_PARAM | ||
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4] | ||
CFG.MODEL.TARGET_FEATURES = [0] | ||
|
||
############################## Metrics Configuration ############################## | ||
|
||
CFG.METRICS = EasyDict() | ||
# Metrics settings | ||
CFG.METRICS.FUNCS = EasyDict({ | ||
'MAE': masked_mae, | ||
'MSE': masked_mse, | ||
'RMSE': masked_rmse, | ||
'MAPE': masked_mape | ||
}) | ||
CFG.METRICS.TARGET = 'MAE' | ||
CFG.METRICS.NULL_VAL = NULL_VAL | ||
|
||
############################## Training Configuration ############################## | ||
CFG.TRAIN = EasyDict() | ||
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
MODEL_ARCH.__name__, | ||
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) | ||
) | ||
CFG.TRAIN.LOSS = masked_mae | ||
# Optimizer settings | ||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = "Adam" | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
"lr": 0.0003, | ||
} | ||
# Learning rate scheduler settings | ||
CFG.TRAIN.LR_SCHEDULER = EasyDict() | ||
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" | ||
CFG.TRAIN.LR_SCHEDULER.PARAM = { | ||
"milestones": [1, 25, 50], | ||
"gamma": 0.5 | ||
} | ||
CFG.TRAIN.CLIP_GRAD_PARAM = { | ||
'max_norm': 5.0 | ||
} | ||
# Train data loader settings | ||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 64 | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
CFG.TRAIN.EARLY_STOPPING_PATIENCE = 10 | ||
|
||
############################## Validation Configuration ############################## | ||
CFG.VAL = EasyDict() | ||
CFG.VAL.INTERVAL = 1 | ||
CFG.VAL.DATA = EasyDict() | ||
CFG.VAL.DATA.BATCH_SIZE = 64 | ||
|
||
############################## Test Configuration ############################## | ||
CFG.TEST = EasyDict() | ||
CFG.TEST.INTERVAL = 1 | ||
CFG.TEST.DATA = EasyDict() | ||
CFG.TEST.DATA.BATCH_SIZE = 64 | ||
|
||
############################## Evaluation Configuration ############################## | ||
|
||
CFG.EVAL = EasyDict() | ||
|
||
# Evaluation parameters | ||
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import os | ||
import sys | ||
from easydict import EasyDict | ||
sys.path.append(os.path.abspath(__file__ + '/../../..')) | ||
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse | ||
from basicts.data import TimeSeriesForecastingDataset | ||
from basicts.runners import SimpleTimeSeriesForecastingRunner | ||
from basicts.scaler import ZScoreScaler | ||
from basicts.utils import get_regular_settings | ||
|
||
from .arch import SOFTS | ||
|
||
############################## Hot Parameters ############################## | ||
# Dataset & Metrics configuration | ||
DATA_NAME = 'ETTh2' # Dataset name | ||
regular_settings = get_regular_settings(DATA_NAME) | ||
# INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence | ||
INPUT_LEN = 192 # better performance | ||
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence | ||
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios | ||
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data | ||
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data | ||
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data | ||
# Model architecture and parameters | ||
MODEL_ARCH = SOFTS | ||
NUM_NODES = 7 | ||
MODEL_PARAM = { | ||
"enc_in": NUM_NODES, # num nodes | ||
"dec_in": NUM_NODES, | ||
"c_out": NUM_NODES, | ||
"seq_len": INPUT_LEN, | ||
"pred_len": OUTPUT_LEN, # prediction sequence length | ||
"e_layers": 2, # num of encoder layers | ||
"d_model": 128, | ||
"d_core": 64, | ||
"d_ff": 128, | ||
"dropout": 0.0, | ||
"use_norm" : True, | ||
"activation": "gelu", | ||
"num_time_features": 4, # number of used time features | ||
"time_of_day_size": 24, | ||
"day_of_week_size": 7, | ||
"day_of_month_size": 31, | ||
"day_of_year_size": 366 | ||
} | ||
NUM_EPOCHS = 20 | ||
|
||
############################## General Configuration ############################## | ||
CFG = EasyDict() | ||
# General settings | ||
CFG.DESCRIPTION = 'An Example Config' | ||
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) | ||
# Runner | ||
CFG.RUNNER = SimpleTimeSeriesForecastingRunner | ||
|
||
CFG.ENV = EasyDict() # Environment settings. Default: None | ||
CFG.ENV.SEED = 2024 # Random seed. Default: None | ||
|
||
############################## Dataset Configuration ############################## | ||
CFG.DATASET = EasyDict() | ||
# Dataset settings | ||
CFG.DATASET.NAME = DATA_NAME | ||
CFG.DATASET.TYPE = TimeSeriesForecastingDataset | ||
CFG.DATASET.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, | ||
'input_len': INPUT_LEN, | ||
'output_len': OUTPUT_LEN, | ||
# 'mode' is automatically set by the runner | ||
}) | ||
|
||
############################## Scaler Configuration ############################## | ||
CFG.SCALER = EasyDict() | ||
# Scaler settings | ||
CFG.SCALER.TYPE = ZScoreScaler # Scaler class | ||
CFG.SCALER.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_ratio': TRAIN_VAL_TEST_RATIO[0], | ||
'norm_each_channel': NORM_EACH_CHANNEL, | ||
'rescale': RESCALE, | ||
}) | ||
|
||
############################## Model Configuration ############################## | ||
CFG.MODEL = EasyDict() | ||
# Model settings | ||
CFG.MODEL.NAME = MODEL_ARCH.__name__ | ||
CFG.MODEL.ARCH = MODEL_ARCH | ||
CFG.MODEL.PARAM = MODEL_PARAM | ||
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4] | ||
CFG.MODEL.TARGET_FEATURES = [0] | ||
|
||
############################## Metrics Configuration ############################## | ||
|
||
CFG.METRICS = EasyDict() | ||
# Metrics settings | ||
CFG.METRICS.FUNCS = EasyDict({ | ||
'MAE': masked_mae, | ||
'MSE': masked_mse, | ||
'RMSE': masked_rmse, | ||
'MAPE': masked_mape | ||
}) | ||
CFG.METRICS.TARGET = 'MAE' | ||
CFG.METRICS.NULL_VAL = NULL_VAL | ||
|
||
############################## Training Configuration ############################## | ||
CFG.TRAIN = EasyDict() | ||
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
MODEL_ARCH.__name__, | ||
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) | ||
) | ||
CFG.TRAIN.LOSS = masked_mse | ||
# Optimizer settings | ||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = "Adam" | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
"lr": 0.0003, | ||
} | ||
# Learning rate scheduler settings | ||
CFG.TRAIN.LR_SCHEDULER = EasyDict() | ||
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" | ||
CFG.TRAIN.LR_SCHEDULER.PARAM = { | ||
"milestones": [1, 25, 50], | ||
"gamma": 0.5 | ||
} | ||
CFG.TRAIN.CLIP_GRAD_PARAM = { | ||
'max_norm': 5.0 | ||
} | ||
# Train data loader settings | ||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 32 | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
CFG.TRAIN.EARLY_STOPPING_PATIENCE = 10 | ||
|
||
############################## Validation Configuration ############################## | ||
CFG.VAL = EasyDict() | ||
CFG.VAL.INTERVAL = 1 | ||
CFG.VAL.DATA = EasyDict() | ||
CFG.VAL.DATA.BATCH_SIZE = 64 | ||
|
||
############################## Test Configuration ############################## | ||
CFG.TEST = EasyDict() | ||
CFG.TEST.INTERVAL = 1 | ||
CFG.TEST.DATA = EasyDict() | ||
CFG.TEST.DATA.BATCH_SIZE = 64 | ||
|
||
############################## Evaluation Configuration ############################## | ||
|
||
CFG.EVAL = EasyDict() | ||
|
||
# Evaluation parameters | ||
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters