-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
76 lines (67 loc) · 2.07 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# -*- coding:utf-8 -*-
from sklearn import linear_model
import matplotlib as mpl
import matplotlib.pyplot as pl
import getdata
import numpy as np
from qhutil import *
# Plot settings.
# - font.sans-serif: use SimHei so the Chinese axis labels/legend render.
# - axes.unicode_minus: draw the minus sign as an ASCII hyphen so negative
#   tick values display correctly with this font.
# - xtick/ytick labelsize: enlarge the tick labels on both axes.
mpl.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
# Apply the 'fivethirtyeight' plotting style last, after the settings above.
pl.style.use('fivethirtyeight')
# 岭回归 根据水位和日期
def fit(point = 'C4-A22-PL-01', day = '2016-07-16', delta = -15):
start = addDay(day=day, delta=delta)
end = day
dt,val = getdata.getDataByPoint(point=point, start=start, end=end)
print dt
# 水位数据
wl = []
for i in range(-delta):
wl.append(getdata.getWLByDay(addDay(start,i)))
dt.append(day)
x = [getTimestampByStr(d) for d in dt]
x = np.array(norm(x))
fitdatas = [[x[i],wl[i]] for i in range(len(wl))]
clf = linear_model.Lasso(alpha = 0.6)
clf.fit(fitdatas, val)
realVal = getdata.getDataByDay(point,day)
predict = clf.predict(np.array([x[-1],getdata.getWLByDay(day)]).reshape((1,-1)))[0]
print '======================'
print 'realVal:', realVal
print 'predict:', predict
print 'aberror:', '%.4f' % np.abs(realVal-predict)
print 'errrate:', '%.4f' % errrate(predict, realVal)
print '======================'
return predict
def fitday(day = '2016-04-01', point='C4-A22-IP-01', period=15):
res = []
t0 = time.clock()
try:
dt,val = getdata.getDataByPoint(point=point, start=day, end=addDay(day,period))
except:
return
for d in dt:
res.append(fit(point=point ,day=d))
print u'需要时间','%.4f' % (time.clock() - t0), 's'
fig = pl.figure(figsize=(25, 20))
ax = fig.add_subplot(111)
xticks = range(0,len(dt),len(dt)/10+1)
xticklabels = [dt[i] for i in xticks]
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, rotation=15)
ax.set_xlabel(u'日期')
ax.set_ylabel(u'测值')
pl.plot(val,label=u"真实数据")
pl.plot(res,label=u"拟合数据")
pl.title(point)
pl.legend(loc=0)
pl.show()
if __name__ == '__main__':
    # Demo run: backtest fitday's default point over 35 days from 2014-07-01.
    fitday(period=35, day='2014-07-01')