# main.py -- forked from wangshub/RL-Stock
import os
import pickle
import pandas as pd
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
from rlenv.StockTradingEnv0 import StockTradingEnv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# Font properties loaded from a font file at a path relative to the working
# directory; handed to the legend in test_a_stock_trade via ``prop=font``.
# NOTE(review): presumably a CJK-capable typeface for non-ASCII labels -- confirm.
font = fm.FontProperties(fname='font/wqy-microhei.ttc')
# Disabled alternative: select a font family globally instead of per-artist.
# plt.rc('font', family='Source Han Sans CN')
# Render negative tick labels with an ASCII hyphen rather than the Unicode
# minus sign.
plt.rcParams['axes.unicode_minus'] = False
def stock_trade(stock_file):
    """Train a PPO2 agent on one stock's CSV and replay it on the test CSV.

    The test file path is derived from ``stock_file`` by replacing every
    occurrence of ``'train'`` with ``'test'``.

    Args:
        stock_file: Path to the training CSV; it must have a ``date`` column
            because the rows are sorted by it before training.

    Returns:
        A list with the value returned by ``env.render()`` after each test
        step, in step order.
    """
    train_frame = pd.read_csv(stock_file).sort_values('date')

    # The algorithms only accept a vectorized environment, hence the wrapper.
    train_env = DummyVecEnv([lambda: StockTradingEnv(train_frame)])
    agent = PPO2(MlpPolicy, train_env, verbose=0, tensorboard_log='./log')
    agent.learn(total_timesteps=int(1e4))

    test_frame = pd.read_csv(stock_file.replace('train', 'test'))
    test_env = DummyVecEnv([lambda: StockTradingEnv(test_frame)])

    profit_history = []
    observation = test_env.reset()
    # At most one step per test row (minus one); stop early once the
    # environment reports the episode as done.
    for _ in range(len(test_frame) - 1):
        chosen_action, _ = agent.predict(observation)
        observation, _, finished, _ = test_env.step(chosen_action)
        profit_history.append(test_env.render())
        if finished:
            break
    return profit_history
def find_file(path, name):
    """Return the first file under ``path`` whose file name contains ``name``.

    The tree is walked with ``os.walk``; matching is a plain substring test
    against the file name only (directory names are not considered).

    Args:
        path: Root directory to search recursively.
        name: Substring to look for in each file name.

    Returns:
        The joined path of the first match, or ``None`` when nothing matches
        (including when ``path`` does not exist).
    """
    hits = (
        os.path.join(folder, candidate)
        for folder, _subdirs, filenames in os.walk(path)
        for candidate in filenames
        if name in candidate
    )
    # The generator is lazy, so the walk stops at the first hit.
    return next(hits, None)
def test_a_stock_trade(stock_code):
    """Train and evaluate on one stock, then save its profit curve as a PNG.

    The training file is looked up under ``./stockdata/train`` and the plot
    is written to ``./img/<stock_code>.png``.

    Args:
        stock_code: Identifier matched as a substring against training file
            names (e.g. ``'sh.600036'``); also used as the legend label and
            as the output file name.

    Raises:
        FileNotFoundError: If no training file matches ``stock_code``.
    """
    stock_file = find_file('./stockdata/train', str(stock_code))
    if not stock_file:
        # Fail with a clear message instead of passing None on to the CSV
        # reader inside stock_trade.
        raise FileNotFoundError(
            f'no training data file matching {stock_code!r} under ./stockdata/train'
        )

    daily_profits = stock_trade(stock_file)

    fig, ax = plt.subplots()
    # The '-o' format string already selects the circle marker; passing
    # marker='o' as well is redundant and makes matplotlib emit a warning.
    ax.plot(daily_profits, '-o', label=stock_code, ms=10, alpha=0.7, mfc='orange')
    ax.grid()
    ax.set_xlabel('step')
    ax.set_ylabel('profit')
    ax.legend(prop=font)
    # plt.show()
    # Make sure the output directory exists so saving cannot fail on a
    # fresh checkout.
    os.makedirs('./img', exist_ok=True)
    fig.savefig(f'./img/{stock_code}.png')
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
def multi_stock_trade(start_code=600000, max_num=3000):
    """Run stock_trade over a range of numeric stock codes and pickle the results.

    For each code in ``[start_code, start_code + max_num)`` a training file
    is searched under ``./stockdata/train``; codes without a file are
    skipped. The collected profit lists are written to
    ``code-<start_code>-<start_code + max_num>.pkl`` in the working directory.

    Args:
        start_code: First numeric stock code to try. Defaults to the
            previously hard-coded value ``600000``.
        max_num: Number of consecutive codes to try. Defaults to the
            previously hard-coded value ``3000``.
    """
    end_code = start_code + max_num
    group_result = []

    for code in range(start_code, end_code):
        stock_file = find_file('./stockdata/train', str(code))
        if not stock_file:
            continue
        try:
            profits = stock_trade(stock_file)
        except Exception as err:
            # Deliberate best-effort: one failing stock must not abort the
            # whole batch. Report which code failed so it can be retried.
            print(f'{code}: {err}')
        else:
            group_result.append(profits)

    with open(f'code-{start_code}-{end_code}.pkl', 'wb') as f:
        pickle.dump(group_result, f)
# Script entry point: train/evaluate a single stock and save its profit plot.
if __name__ == '__main__':
    # Disabled alternative: run the batch over a whole range of stock codes.
    # multi_stock_trade()
    test_a_stock_trade('sh.600036')
    # Disabled debugging aid: check which training file a code resolves to.
    # ret = find_file('./stockdata/train', '600036')
    # print(ret)