-
Notifications
You must be signed in to change notification settings - Fork 344
/
train.py
85 lines (69 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Script for training Stock Trading Bot.
Usage:
train.py <train-stock> <val-stock> [--strategy=<strategy>]
[--window-size=<window-size>] [--batch-size=<batch-size>]
[--episode-count=<episode-count>] [--model-name=<model-name>]
[--pretrained] [--debug]
Options:
--strategy=<strategy> Q-learning strategy to use for training the network. Options:
`dqn` i.e. Vanilla DQN,
`t-dqn` i.e. DQN with fixed target distribution,
`double-dqn` i.e. DQN with separate network for value estimation. [default: t-dqn]
--window-size=<window-size> Size of the n-day window stock data representation
used as the feature vector. [default: 10]
--batch-size=<batch-size> Number of samples to train on in one mini-batch
during training. [default: 32]
--episode-count=<episode-count> Number of trading episodes to use for training. [default: 50]
--model-name=<model-name> Name of the pretrained model to use. [default: model_debug]
--pretrained Specifies whether to continue training a previously
trained model (reads `model-name`).
--debug Specifies whether to use verbose logs during eval operation.
"""
import logging
import coloredlogs
from docopt import docopt
from trading_bot.agent import Agent
from trading_bot.methods import train_model, evaluate_model
from trading_bot.utils import (
get_stock_data,
format_currency,
format_position,
show_train_result,
switch_k_backend_device
)
def main(train_stock, val_stock, window_size, batch_size, ep_count,
         strategy="t-dqn", model_name="model_debug", pretrained=False,
         debug=False):
    """Train the stock trading bot with Deep Q-Learning, validating after each episode.

    Background on the method: https://arxiv.org/abs/1312.5602
    The command-line meaning of every option is printed by
    `python train.py --help`.

    Args:
        train_stock: Handed to `get_stock_data` to load the training data.
        val_stock: Handed to `get_stock_data` to load the validation data.
        window_size: Forwarded to `Agent`, `train_model` and `evaluate_model`.
        batch_size: Forwarded to `train_model`.
        ep_count: Number of episodes to run; episodes are numbered from 1
            through `ep_count` inclusive.
        strategy: Forwarded to `Agent`.
        model_name: Forwarded to `Agent`.
        pretrained: Forwarded to `Agent`.
        debug: Forwarded to `evaluate_model`.
    """
    bot = Agent(window_size, strategy=strategy, pretrained=pretrained, model_name=model_name)

    train_series = get_stock_data(train_stock)
    val_series = get_stock_data(val_stock)

    # Difference between the second and first validation entries; it is only
    # passed on to `show_train_result` below.
    initial_offset = val_series[1] - val_series[0]

    # These keyword arguments are identical for every episode.
    train_kwargs = dict(ep_count=ep_count, batch_size=batch_size, window_size=window_size)

    for episode_no in range(1, ep_count + 1):
        train_outcome = train_model(bot, episode_no, train_series, **train_kwargs)
        # `evaluate_model` is unpacked as a pair; only the first item is reported.
        val_outcome, _ = evaluate_model(bot, val_series, window_size, debug)
        show_train_result(train_outcome, val_outcome, initial_offset)
if __name__ == "__main__":
    # The module docstring doubles as the command-line specification.
    cli = docopt(__doc__)

    # Positional arguments, in the order `main` declares them. The three
    # numeric options are converted with int() before anything else runs.
    positional = (
        cli["<train-stock>"],
        cli["<val-stock>"],
        int(cli["--window-size"]),
        int(cli["--batch-size"]),
        int(cli["--episode-count"]),
    )

    # Remaining options are handed to `main` by keyword.
    options = {
        "strategy": cli["--strategy"],
        "model_name": cli["--model-name"],
        "pretrained": cli["--pretrained"],
        "debug": cli["--debug"],
    }

    # Logging is installed at DEBUG level regardless of the --debug flag.
    coloredlogs.install(level="DEBUG")
    # NOTE(review): presumably selects the Keras backend device; confirm in
    # trading_bot.utils.
    switch_k_backend_device()

    try:
        main(*positional, **options)
    except KeyboardInterrupt:
        # Ctrl-C ends training with a short message instead of a traceback.
        print("Aborted!")