-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
251 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"import math\n", | ||
"import torch\n", | ||
"\n", | ||
"PROJECT_DIR = os.path.abspath(os.path.abspath('') + \"/../..\")\n", | ||
"os.chdir(PROJECT_DIR)\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"from tqdm import tqdm\n", | ||
"from basicts.data import TimeSeriesForecastingDataset\n", | ||
"from basicts.utils import get_regular_settings\n", | ||
"from basicts.scaler import ZScoreScaler\n", | ||
"\n", | ||
"\n", | ||
"metric = \"cosine\" # metric used to calculate the similarity.\n", | ||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | ||
"\n", | ||
"DATA_NAME = \"METR-LA\"\n", | ||
"DATA_NAME = \"ETTh1\"\n", | ||
"BATCH_SIZE = 8\n", | ||
"regular_settings = get_regular_settings(DATA_NAME)\n", | ||
"INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence\n", | ||
"OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence\n", | ||
"TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios\n", | ||
"RESCALE = regular_settings['RESCALE'] # Whether to rescale the data\n", | ||
"NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data\n", | ||
"NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## utilities" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# similarity computation\n", | ||
"def cosine_similarity(x, y):\n", | ||
" # denominator\n", | ||
" l2_x = torch.norm(x, dim=2, p=2) + 1e-7\n", | ||
" l2_y = torch.norm(y, dim=2, p=2) + 1e-7\n", | ||
" l2_n = torch.matmul(l2_x.unsqueeze(dim=2), l2_y.unsqueeze(dim=2).transpose(1, 2))\n", | ||
" # numerator\n", | ||
" l2_d = torch.matmul(x, y.transpose(1, 2))\n", | ||
" return l2_d / l2_n\n", | ||
"\n", | ||
"def get_similarity_matrix(data, metric):\n", | ||
" if metric == \"cosine\":\n", | ||
" sim = cosine_similarity(data, data)\n", | ||
" elif metric == \"mse\":\n", | ||
" sim = torch.cdist(data, data, p=2)\n", | ||
" elif metric == \"mae\":\n", | ||
" sim = torch.cdist(data, data, p=1)\n", | ||
" else:\n", | ||
" raise NotImplementedError\n", | ||
" return sim" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset_param = {\n", | ||
" 'dataset_name': DATA_NAME,\n", | ||
" 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,\n", | ||
" 'input_len': INPUT_LEN,\n", | ||
" 'output_len': OUTPUT_LEN,\n", | ||
"}\n", | ||
"# get dataloader\n", | ||
"dataset = TimeSeriesForecastingDataset(**dataset_param, mode='train')\n", | ||
"# the whole training data\n", | ||
"dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)\n", | ||
"\n", | ||
"scaler_param = {\n", | ||
" 'dataset_name': DATA_NAME,\n", | ||
" 'train_ratio': TRAIN_VAL_TEST_RATIO[0],\n", | ||
" 'norm_each_channel': NORM_EACH_CHANNEL,\n", | ||
" 'rescale': RESCALE,\n", | ||
"}\n", | ||
"scaler = ZScoreScaler(**scaler_param)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Generate Similarity Matrix" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 997/997 [00:02<00:00, 412.47it/s]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# get similarity matrices\n", | ||
"\n", | ||
"# inference pipeline for a given dataloader\n", | ||
"history_adjs_all = []\n", | ||
"future_adjs_all = []\n", | ||
"def inference(dataloader):\n", | ||
" for batch in tqdm(dataloader):\n", | ||
" future_data, history_data = batch['target'], batch['inputs']\n", | ||
" future_data = scaler.transform(future_data)\n", | ||
" history_data = scaler.transform(history_data)\n", | ||
" history_data = history_data[..., 0].transpose(1, 2) # batch_size, num_nodes, history_seq_len\n", | ||
" future_data = future_data[..., 0].transpose(1, 2) # batch_size, num_nodes, future_seq_len\n", | ||
" history_adjs = get_similarity_matrix(history_data, metric) # batch_size, num_nodes, num_nodes\n", | ||
" future_adjs = get_similarity_matrix(future_data, metric) # batch_size, num_nodes, num_nodes\n", | ||
" history_adjs_all.append(history_adjs)\n", | ||
" future_adjs_all.append(future_adjs)\n", | ||
"# get similarity matrices\n", | ||
"# for mode in [\"valid\"]:\n", | ||
"for mode in [\"train\"]:\n", | ||
" inference(dataloader)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"torch.Size([7969, 7, 7])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# get spatial indistinguishability ratio\n", | ||
"history_similarity = torch.cat(history_adjs_all, dim=0).detach().cpu() # num_samples, num_modes, num_nodes\n", | ||
"future_similarity = torch.cat(future_adjs_all, dim=0).detach().cpu() # num_samples, num_modes, num_nodes\n", | ||
"L, N, N = future_similarity.shape\n", | ||
"print(future_similarity.shape)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get Spatial Indistinguishability Ratio" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"e_u = 0.9\n", | ||
"e_l = 0.4\n", | ||
"\n", | ||
"history_similarity_filtered = torch.where(history_similarity > e_u, torch.ones_like(history_similarity), torch.zeros_like(history_similarity))\n", | ||
"future_similarity_filtered = torch.where(future_similarity < e_l, torch.ones_like(future_similarity), torch.zeros_like(future_similarity))\n", | ||
"overlap = history_similarity_filtered * future_similarity_filtered\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"tensor(3.8568)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# overlap ratio\n", | ||
"overlap_ratio = overlap.sum() / (L * N * N)\n", | ||
"print(overlap_ratio * 1000)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"tensor(15.7748)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# indistinguishability ratio\n", | ||
"indistinguishability_ratio = overlap.sum() / history_similarity_filtered.sum()\n", | ||
"print(indistinguishability_ratio * 1000)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "BasicTS", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.11" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |