-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_lorenz.py
72 lines (58 loc) · 2.7 KB
/
predict_lorenz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import argparse
import pdb
import time
import torch
import numpy as np
# TensorBoard
from torch.utils.tensorboard import SummaryWriter
from model import load_model, save_model
from data.loaders import lorenz_loader
from utils.util import linear_alignment
from utils.metrics import compute_R2
def encode_test(args, model):
    """Encode the Lorenz test split with ``model`` and score it against the true dynamics.

    The whole test sequence is passed through the model's latent encoder as a
    single batch element. The resulting latents are linearly aligned to the
    ground-truth Lorenz state ``X_dynamics``, and the fit is scored with
    ``compute_R2``.

    Args:
        args: Parsed command-line namespace. It is forwarded to
            ``lorenz_loader`` and must provide ``num_workers`` and ``device``.
        model: Object exposing ``model.get_latent_representations`` (i.e. the
            call is ``model.model.get_latent_representations``). ``main``
            passes ``.module`` of the model returned by ``load_model``.

    Returns:
        Whatever ``compute_R2`` returns for the aligned latents vs.
        ``X_dynamics`` (presumably an R^2 score; confirm in utils.metrics).
    """
    # Only the test split and the ground-truth dynamics are needed here.
    _, _, _, test_dataset, X_dynamics = lorenz_loader(args, num_workers=args.num_workers)

    # Prepend a batch axis so the full test sequence is encoded in one pass.
    # The [None, :, :] indexing requires get_full_data() to return a numpy
    # array with at least 2 dims (presumably time x features; TODO confirm).
    full_sequence = torch.from_numpy(test_dataset.get_full_data()[None, :, :]).to(args.device)

    # Pure inference: skip building an autograd graph over the whole sequence.
    # This saves memory and leaves the computed values unchanged.
    # NOTE(review): the model is not put in eval() mode here. If it contains
    # dropout/batch-norm layers, confirm load_model() already does that.
    with torch.no_grad():
        z, _ = model.model.get_latent_representations(full_sequence)

    # Move to CPU and drop the batch axis before the alignment step.
    latents = z.detach().to('cpu')[0]
    aligned_encoded_signals = linear_alignment(latents, X_dynamics)
    return compute_R2(aligned_encoded_signals, X_dynamics)
def _build_parser():
    """Build the command-line parser for the Lorenz evaluation script.

    Only ``seed``, ``start_epoch`` and ``num_workers`` are read directly in this
    file. All options travel on the ``args`` namespace into ``load_model`` and
    ``lorenz_loader`` and are presumably consumed there, so confirm before
    removing any.
    """
    parser = argparse.ArgumentParser(description='Lorenz experiment.')
    parser.add_argument('--out_dir', type=str, default="./result/lorenz")
    parser.add_argument('--experiment', type=str, default="lorenz")
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--snr_index', type=int, default=0)
    # CPC
    parser.add_argument('--learning_rate', type=float, default=2.0e-4)
    parser.add_argument('--negative_samples', type=int, default=10)
    parser.add_argument('--prediction_step', type=int, default=12)
    parser.add_argument('--subsample', action="store_true")
    # General
    parser.add_argument('--genc_input', type=int, default=30)
    parser.add_argument('--seed', type=int, default=22)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--data_input_dir', type=str, default="./datasets/lorenz/lorenz_exploration.hdf5")
    parser.add_argument('--data_output_dir', type=str, default=".")
    parser.add_argument('--validate', action="store_true")
    parser.add_argument('--fp16', action="store_true")
    parser.add_argument('--calc_accuracy', action="store_true")
    # Reload
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--model_path', type=str, default="result/lorenz")
    parser.add_argument('--model_num', type=int, default=10)
    return parser


def main():
    """Reload a trained Lorenz model, evaluate it on the test split and print the score.

    Steps:
        1. Parse command-line options.
        2. Record run metadata on ``args``.
        3. Seed the NumPy and PyTorch RNGs.
        4. Reload the model via ``load_model(args, reload_model=True)``.
        5. Print the value returned by ``encode_test``.
    """
    args = _build_parser().parse_args()

    # Run metadata stored on ``args``. It is not read in this file; presumably
    # it is used by the project helpers that receive ``args`` (TODO confirm).
    args.time = time.ctime()

    # Prefer the first GPU when available; encode_test moves its input here.
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.current_epoch = args.start_epoch

    # Seed the NumPy and PyTorch RNGs from --seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load_model returns (model, optimizer); the optimizer is not needed for
    # evaluation.
    model, _ = load_model(args, reload_model=True)

    # ``.module`` unwraps the loaded model. Presumably this is a
    # DataParallel-style wrapper; verify against load_model.
    R2_CPC = encode_test(args, model.module)
    print(R2_CPC)
if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()