-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_with_gym.py
83 lines (74 loc) · 2.82 KB
/
training_with_gym.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gym
import mlagents_envs
# from baselines import deepq
# from baselines import logger
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from gym_trainer.a2c import MyA2C
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from torch.utils.tensorboard import SummaryWriter
import sys
import torch
import numpy as np
import random
from pathlib import Path
def main(run_id):
    """Train a MyA2C agent on a Unity environment exposed via the Gym API.

    Args:
        run_id: Identifier used to name both the TensorBoard run directory
            (``runs/<run_id>``) and the checkpoint directory
            (``checkpoints/<run_id>``).
    """
    # Seed every RNG *before* constructing the environment/model so that any
    # randomness consumed during initialization is reproducible too.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    # With no file_name argument this blocks until a Unity editor connects.
    # unity_env = UnityEnvironment("/Users/rishimalhotra/projects/checker2_one_agent.app")
    unity_env = UnityEnvironment()
    vec_env = UnityToGymWrapper(unity_env, uint8_visual=True)
    print('hi', vec_env._observation_space, isinstance(
        vec_env.action_space, gym.spaces.Box), vec_env.action_space)

    writer = SummaryWriter(f"runs/{run_id}")
    print('run id: ', run_id)

    # Optionally resume from earlier checkpoints; None trains from scratch.
    # load_policy_network_checkpoint_path = "good_checkpoints/increasing_n_steps_to_15/step_25000/policy_network.pth"
    # load_value_network_checkpoint_path = "good_checkpoints/increasing_n_steps_to_15/step_25000/value_network.pth"
    load_policy_network_checkpoint_path = None
    load_value_network_checkpoint_path = None

    checkpoint_path = Path(f"checkpoints/{run_id}")
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    try:
        model = MyA2C(vec_env,
                      writer,
                      n_steps=15,
                      num_rollouts_per_update=10,
                      load_policy_network_checkpoint_path=load_policy_network_checkpoint_path,
                      load_value_network_checkpoint_path=load_value_network_checkpoint_path)
        model.learn(total_timesteps=50000,
                    policy_network_lr=7e-4,
                    value_network_lr=7e-4,
                    sigma_lr=1e-4,
                    ent_coef=1e-2,
                    checkpoint_path=checkpoint_path)
    finally:
        # Always release the Unity connection and flush TensorBoard logs,
        # even if training raises — otherwise the editor/socket is leaked.
        vec_env.close()
        writer.close()
if __name__ == '__main__':
    # Parse --run-id=<id> from the CLI. Fail loudly when it is missing
    # instead of silently exiting without training, and honor only the
    # first occurrence instead of launching one training run per match.
    run_id = None
    for arg in sys.argv[1:]:
        if arg.startswith("--run-id="):
            run_id = arg[len("--run-id="):]
            break
    if run_id is None:
        sys.exit("usage: training_with_gym.py --run-id=<run_id>")
    main(run_id)