-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathinference.py
47 lines (45 loc) · 1.82 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import matplotlib.pyplot as plt
from model.model import ViTime
import numpy as np
import torch
from scipy import interpolate
def interpolate_to_512(original_sequence, target_length=512):
    """Linearly resample a 1-D sequence onto ``target_length`` evenly spaced points.

    Both the input samples and the output samples are placed evenly on the
    unit interval [0, 1], so the first and last values are preserved and
    interior values are linearly interpolated.

    Args:
        original_sequence: 1-D array-like of numeric samples.
        target_length: Number of output samples. Defaults to 512, the maximum
            model input length when ``args.upscal`` is True.

    Returns:
        np.ndarray of shape ``(target_length,)``.

    Raises:
        ValueError: If ``original_sequence`` is empty.
    """
    n = len(original_sequence)
    if n == 0:
        raise ValueError("original_sequence must contain at least one sample")
    if n == 1:
        # interp1d needs at least two points; a single sample resamples to a
        # constant sequence (promoted to float like interp1d's output).
        value = np.asarray(original_sequence)[0]
        return np.full(target_length, value, dtype=np.result_type(value, np.float64))
    x_original = np.linspace(0, 1, n)
    x_interpolated = np.linspace(0, 1, target_length)
    f = interpolate.interp1d(x_original, original_sequence)
    return f(x_interpolated)
def inverse_interpolate(processed_sequence, original_length,
                        context_length=512, prediction_length=720):
    """Resample a model forecast back onto the original sampling rate.

    The input history of ``original_length`` samples is assumed to have been
    resampled to ``context_length`` points before inference, so the forecast
    is stretched by the same factor: it is linearly resampled to
    ``int(original_length * prediction_length / context_length)`` points
    (truncated, not rounded).

    The defaults 512/720 match the max input/prediction lengths noted for
    ``args.upscal = True``; pass 1024/1440 for ``args.upscal = False``.
    NOTE(review): this presumes the forecast spans ``prediction_length``
    steps of the resampled grid — confirm against ``ViTime.inference``.

    Args:
        processed_sequence: 1-D forecast returned by the model (at least two
            samples, as required by ``scipy.interpolate.interp1d``).
        original_length: Length of the history before resampling.
        context_length: Length the history was resampled to.
        prediction_length: Forecast horizon on the resampled grid.

    Returns:
        np.ndarray containing the forecast on the original time scale.
    """
    processed_length = len(processed_sequence)
    z = int(original_length * prediction_length / context_length)
    x_processed = np.linspace(0, 1, processed_length)
    x_inverse = np.linspace(0, 1, z)
    f_inverse = interpolate.interp1d(x_processed, processed_sequence)
    return f_inverse(x_inverse)
# Select the compute device. torch.cuda.set_device fails on machines without
# a usable CUDA device, so only call it when CUDA is available; otherwise the
# script falls back to the CPU.
deviceNum = 0
if torch.cuda.is_available():
    torch.cuda.set_device(deviceNum)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the checkpoint carries a pickled ``args`` object, so torch.load
# unpickles arbitrary Python objects here -- only load trusted checkpoints.
# On PyTorch versions where ``weights_only`` defaults to True this call may
# also need ``weights_only=False``; confirm against the installed version.
checkpoint = torch.load(r'ViTime_V1.pth', map_location=device)
args = checkpoint['args']
args.device = device
args.flag = 'test'
# Set upscaling parameters
args.upscal = True  # True: max input length = 512, max prediction length = 720
                    # False: max input length = 1024, max prediction length = 1440
model = ViTime(args=args)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Example data: a synthetic 512-sample signal built from three sinusoids.
xData = np.sin(np.arange(512) / 10) + np.sin(np.arange(512) / 5 + 50) + np.cos(np.arange(512) + 50)
# Resample the history onto the 512-point model grid (effectively an identity
# resampling here, since the example already has 512 samples).
interpolated_sequence = interpolate_to_512(xData)
args.realInputLength = len(interpolated_sequence)
# Pure inference: disable autograd so no computation graph is recorded.
with torch.no_grad():
    yp = model.inference(interpolated_sequence).flatten()
# Stretch the forecast back onto the original sampling rate.
yp = inverse_interpolate(yp, len(xData))
# Plot results. plt.plot with a single array uses sample indices as x, so the
# "Prediction" curve (input followed by forecast) continues from where the
# input ends; the input is then drawn on top of its own leading segment.
plt.plot(np.concatenate((xData, yp.flatten())), label='Prediction')
plt.plot(xData, label='Input Sequence')
plt.legend()
plt.show()