inferenceColab.py
import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import interpolate

from model.model import ViTime


def interpolate_to_512(original_sequence):
    """Linearly resample a 1-D sequence of arbitrary length onto 512 points."""
    n = len(original_sequence)
    x_original = np.linspace(0, 1, n)
    x_interpolated = np.linspace(0, 1, 512)
    f = interpolate.interp1d(x_original, original_sequence)
    return f(x_interpolated)


def inverse_interpolate(processed_sequence, original_length):
    """Resample a processed sequence back to the original sampling rate.

    The target length is int(original_length * 720 / 512), i.e. the model's
    prediction-to-input length ratio applied at the original resolution.
    """
    processed_length = len(processed_sequence)
    z = int(original_length * 720 / 512)
    x_processed = np.linspace(0, 1, processed_length)
    x_inverse = np.linspace(0, 1, z)
    f_inverse = interpolate.interp1d(x_processed, processed_sequence)
    return f_inverse(x_inverse)
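

# Illustrative round trip for the two helpers above (a minimal sketch; the demo
# array is made up for this comment and is not part of the ViTime repo). Note
# that in main() below, inverse_interpolate is applied to the model prediction,
# not to the resampled input:
#
#   demo = np.sin(np.linspace(0, 2 * np.pi, 100))
#   up = interpolate_to_512(demo)              # shape (512,), model-ready input
#   back = inverse_interpolate(up, len(demo))  # shape (int(100 * 720 / 512),) == (140,)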

def main(modelpath, savepath):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the checkpoint; it stores both the training args and the model weights
    checkpoint = torch.load(modelpath, map_location=device)
    args = checkpoint['args']
    args.device = device
    args.flag = 'test'

    # Set upscaling parameters
    # True:  max input length = 512,  max prediction length = 720
    # False: max input length = 1024, max prediction length = 1440
    args.upscal = True

    model = ViTime(args=args)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    # Example data: a synthetic mixture of sinusoids
    xData = np.sin(np.arange(512) / 10) + np.sin(np.arange(512) / 5 + 50) + np.cos(np.arange(512) + 50)

    # Resample the input to the model's expected length and run inference
    interpolated_sequence = interpolate_to_512(xData)
    args.realInputLength = len(interpolated_sequence)
    yp = model.inference(interpolated_sequence).flatten()

    # Map the prediction back to the original sampling rate
    yp = inverse_interpolate(yp, len(xData)).flatten()

    # Plot the input sequence followed by the prediction
    plt.plot(np.concatenate([xData, yp], axis=0), label='Prediction')
    plt.plot(xData, label='Input Sequence')
    plt.legend()
    plt.savefig(savepath)  # Save the figure to the specified path
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='ViTime model inference')
    parser.add_argument('--modelpath', type=str, required=True,
                        help='Path to the model checkpoint file')
    parser.add_argument('--savepath', type=str, default='plot.png',
                        help='Path to save the plot image (default: plot.png)')
    args = parser.parse_args()
    main(args.modelpath, args.savepath)
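
# Example invocation (a sketch; the checkpoint path below is a placeholder, not
# a file shipped with this script):
#
#   python inferenceColab.py --modelpath /path/to/checkpoint.pth --savepath prediction.png
#
# With args.upscal = True, the resulting figure shows the 512-point synthetic
# input followed by the 720-point forecast produced by ViTime.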