Skip to content

Commit

Permalink
Merge branch 'GestaltCogTeam:master' into contribution
Browse files Browse the repository at this point in the history
  • Loading branch information
superarthurlx authored Jan 2, 2025
2 parents 59a0f95 + ab8d29d commit 489884e
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ BasicTS is developed based on [EasyTorch](https://github.com/cnstark/easytorch),
We invite you to join our official community to access comprehensive technical support.
Official Discord Server: [Click here to join our Discord community](https://discord.gg/chxdaxG4)
Official Discord Server: [Click here to join our Discord community](https://discord.gg/UWFP7b3b7H)
Official WeChat Group:
Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,4 @@ BasicTS 是基于 [EasyTorch](https://github.com/cnstark/easytorch) 开发的,
![wechat](assets/BasicTS-wechat-cn.jpg)
官方Discord频道: [点击加入我们的 Discord 社区](https://discord.gg/chxdaxG4)
官方Discord频道: [点击加入我们的 Discord 社区](https://discord.gg/UWFP7b3b7H)
32 changes: 3 additions & 29 deletions baselines/MTGNN/runner/mtgnn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,9 @@ def __init__(self, cfg: dict):
self.num_split = cfg.TRAIN.CUSTOM.NUM_SPLIT
self.perm = None

def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
"""Select input features.
Args:
data (torch.Tensor): input history data, shape [B, L, N, C]
Returns:
torch.Tensor: reshaped data
"""

# select feature using self.forward_features
if self.forward_features is not None:
data = data[:, :, :, self.forward_features]
return data

def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
"""Select target feature
Args:
data (torch.Tensor): prediction of the model with arbitrary shape.
Returns:
torch.Tensor: reshaped data with shape [B, L, N, C]
"""

# select feature using self.target_features
data = data[:, :, :, self.target_features]
return data

def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple:
data = self.preprocessing(data)

if train:
future_data, history_data, idx = data['target'], data['inputs'], data['idx']
else:
Expand All @@ -68,6 +41,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
model_return["target"] = self.select_target_features(future_data)
assert list(model_return["prediction"].shape)[:3] == [batch_size, seq_len, num_nodes], \
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
model_return = self.postprocessing(model_return)
return model_return

def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
Expand Down
6 changes: 3 additions & 3 deletions baselines/STAEformer/arch/staeformer_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
batch_size = x.shape[0]

if self.tod_embedding_dim > 0:
tod = x[..., 1]
tod = x[..., 1] * self.steps_per_day
if self.dow_embedding_dim > 0:
dow = x[..., 2]
dow = x[..., 2] * 7
x = x[..., : self.input_dim]

x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)
features = [x]
if self.tod_embedding_dim > 0:
tod_emb = self.tod_embedding(
(tod * self.steps_per_day).long()
tod.long()
) # (batch_size, in_steps, num_nodes, tod_embedding_dim)
features.append(tod_emb)
if self.dow_embedding_dim > 0:
Expand Down
251 changes: 251 additions & 0 deletions scripts/data_visualization/indistinguishability.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import math\n",
"import torch\n",
"\n",
"PROJECT_DIR = os.path.abspath(os.path.abspath('') + \"/../..\")\n",
"os.chdir(PROJECT_DIR)\n",
"\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from basicts.data import TimeSeriesForecastingDataset\n",
"from basicts.utils import get_regular_settings\n",
"from basicts.scaler import ZScoreScaler\n",
"\n",
"\n",
"metric = \"cosine\" # metric used to calculate the similarity.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"DATA_NAME = \"METR-LA\"\n",
"DATA_NAME = \"ETTh1\"\n",
"BATCH_SIZE = 8\n",
"regular_settings = get_regular_settings(DATA_NAME)\n",
"INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence\n",
"OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence\n",
"TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios\n",
"RESCALE = regular_settings['RESCALE'] # Whether to rescale the data\n",
"NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data\n",
"NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## utilities"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# similarity computation\n",
"def cosine_similarity(x, y):\n",
" # denominator\n",
" l2_x = torch.norm(x, dim=2, p=2) + 1e-7\n",
" l2_y = torch.norm(y, dim=2, p=2) + 1e-7\n",
" l2_n = torch.matmul(l2_x.unsqueeze(dim=2), l2_y.unsqueeze(dim=2).transpose(1, 2))\n",
" # numerator\n",
" l2_d = torch.matmul(x, y.transpose(1, 2))\n",
" return l2_d / l2_n\n",
"\n",
"def get_similarity_matrix(data, metric):\n",
" if metric == \"cosine\":\n",
" sim = cosine_similarity(data, data)\n",
" elif metric == \"mse\":\n",
" sim = torch.cdist(data, data, p=2)\n",
" elif metric == \"mae\":\n",
" sim = torch.cdist(data, data, p=1)\n",
" else:\n",
" raise NotImplementedError\n",
" return sim"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"dataset_param = {\n",
" 'dataset_name': DATA_NAME,\n",
" 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,\n",
" 'input_len': INPUT_LEN,\n",
" 'output_len': OUTPUT_LEN,\n",
"}\n",
"# get dataloader\n",
"dataset = TimeSeriesForecastingDataset(**dataset_param, mode='train')\n",
"# the whole training data\n",
"dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)\n",
"\n",
"scaler_param = {\n",
" 'dataset_name': DATA_NAME,\n",
" 'train_ratio': TRAIN_VAL_TEST_RATIO[0],\n",
" 'norm_each_channel': NORM_EACH_CHANNEL,\n",
" 'rescale': RESCALE,\n",
"}\n",
"scaler = ZScoreScaler(**scaler_param)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate Similarity Matrix"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 997/997 [00:02<00:00, 412.47it/s]\n"
]
}
],
"source": [
"# get similarity matrices\n",
"\n",
"# inference pipeline for a given dataloader\n",
"history_adjs_all = []\n",
"future_adjs_all = []\n",
"def inference(dataloader):\n",
" for batch in tqdm(dataloader):\n",
" future_data, history_data = batch['target'], batch['inputs']\n",
" future_data = scaler.transform(future_data)\n",
" history_data = scaler.transform(history_data)\n",
" history_data = history_data[..., 0].transpose(1, 2) # batch_size, num_nodes, history_seq_len\n",
" future_data = future_data[..., 0].transpose(1, 2) # batch_size, num_nodes, future_seq_len\n",
" history_adjs = get_similarity_matrix(history_data, metric) # batch_size, num_nodes, num_nodes\n",
" future_adjs = get_similarity_matrix(future_data, metric) # batch_size, num_nodes, num_nodes\n",
" history_adjs_all.append(history_adjs)\n",
" future_adjs_all.append(future_adjs)\n",
"# get similarity matrices\n",
"# for mode in [\"valid\"]:\n",
"for mode in [\"train\"]:\n",
" inference(dataloader)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7969, 7, 7])\n"
]
}
],
"source": [
"# get spatial indistinguishability ratio\n",
"history_similarity = torch.cat(history_adjs_all, dim=0).detach().cpu() # num_samples, num_modes, num_nodes\n",
"future_similarity = torch.cat(future_adjs_all, dim=0).detach().cpu() # num_samples, num_modes, num_nodes\n",
"L, N, N = future_similarity.shape\n",
"print(future_similarity.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Spatial Indistinguishability Ratio"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"e_u = 0.9\n",
"e_l = 0.4\n",
"\n",
"history_similarity_filtered = torch.where(history_similarity > e_u, torch.ones_like(history_similarity), torch.zeros_like(history_similarity))\n",
"future_similarity_filtered = torch.where(future_similarity < e_l, torch.ones_like(future_similarity), torch.zeros_like(future_similarity))\n",
"overlap = history_similarity_filtered * future_similarity_filtered\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(3.8568)\n"
]
}
],
"source": [
"# overlap ratio\n",
"overlap_ratio = overlap.sum() / (L * N * N)\n",
"print(overlap_ratio * 1000)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(15.7748)\n"
]
}
],
"source": [
"# indistinguishability ratio\n",
"indistinguishability_ratio = overlap.sum() / history_similarity_filtered.sum()\n",
"print(indistinguishability_ratio * 1000)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "BasicTS",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 489884e

Please sign in to comment.