training_continue.py
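"""Continue training a previously trained NerveNet / Stable-Baselines3 agent.

The script loads `train_arguments.yaml` and the saved model from a pretrained
run directory, rebuilds the vectorized environment for the same task, resumes
PPO/A2C training for `--total_timesteps` steps, and saves the resulting model
into a new tensorboard run directory.
"""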
import argparse
import copy
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Callable

import gym
import pyaml
import pybullet_data
import pybullet_envs  # register pybullet envs from bullet3
import torch
import yaml
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import get_device
from stable_baselines3.ppo import MlpPolicy
from torch import nn

# NerveNet imports; register_policies and register_disability_envs are
# imported for their side effect of registering the custom policies and envs.
import NerveNet.gym_envs.pybullet.register_disability_envs
from NerveNet.graph_util.mujoco_parser_settings import (ControllerOption,
                                                        EmbeddingOption,
                                                        RootRelationOption)
from NerveNet.models import nerve_net_conv
from NerveNet.policies import register_policies
from util import LoggingCallback
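# Lookup tables that map configuration strings to the corresponding algorithm
# classes, activation functions, and NerveNet graph options.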
algorithms = dict(A2C=A2C, PPO=PPO)
activation_functions = dict(Tanh=nn.Tanh, ReLU=nn.ReLU)
controller_option = dict(shared=ControllerOption.SHARED,
                         seperate=ControllerOption.SEPERATE,
                         unified=ControllerOption.UNIFIED)
embedding_option = dict(shared=EmbeddingOption.SHARED,
                        unified=EmbeddingOption.UNIFIED)
root_option = dict(none=RootRelationOption.NONE,
                   body=RootRelationOption.BODY,
                   unified=RootRelationOption.ALL)

def train(args):
    cuda_availability = torch.cuda.is_available()
    print('\n*************************')
    print('`CUDA` available: {}'.format(cuda_availability))
    print('Device specified: {}'.format(args.device))
    print('*************************\n')

    # Load the config of the pretrained model
    with open(args.pretrained_output / "train_arguments.yaml") as yaml_data:
        pretrain_arguments = yaml.load(yaml_data, Loader=yaml.FullLoader)

    # Load the pretrained checkpoint (model name without its file extension)
    pretrained_model = algorithms[pretrain_arguments["alg"]].load(
        args.pretrained_output / "".join(pretrain_arguments["model_name"].split(".")[:-1]),
        device='cpu')
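    # The checkpoint is loaded onto the CPU here; the vectorized training
    # environment is reattached below before learning is resumed.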
    # Prepare tensorboard logging
    log_name = '{}_{}'.format(
        pretrain_arguments["experiment_name"],
        datetime.now().strftime('%d-%m_%H-%M-%S'))
    run_dir = args.tensorboard_log + "/" + log_name
    Path(run_dir).mkdir(parents=True, exist_ok=True)

    callbacks = []
    # callbacks.append(CheckpointCallback(
    #     save_freq=1000000, save_path=run_dir, name_prefix='rl_model'))
    callbacks.append(LoggingCallback(logpath=run_dir))

    # Save the inherited training arguments alongside the new run
    train_args = copy.copy(pretrain_arguments)
    pyaml.dump(train_args, open(
        os.path.join(run_dir, 'train_arguments.yaml'), 'w'))

    # Create the vectorized environment and attach it to the loaded model
    n_envs = pretrain_arguments["n_envs"]  # Number of processes to use
    env = make_vec_env(pretrain_arguments["task_name"], n_envs=n_envs)
    pretrained_model.env = env

    # Resume training and save the final model into the new run directory
    pretrained_model.learn(total_timesteps=args.total_timesteps,
                           callback=callbacks,
                           tb_log_name=log_name)
    pretrained_model.save(os.path.join(run_dir, args.model_name))
def dir_path(path):
    if os.path.isdir(path):
        return Path(path)
    else:
        raise argparse.ArgumentTypeError(
            f"readable_dir:{path} is not a valid path")
def parse_arguments():
    p = argparse.ArgumentParser()
    p.add_argument('--pretrained_output',
                   help="The directory where the pretrained output & configs were logged to",
                   type=dir_path,
                   default='runs/GNN_PPO_inp_32_pro_32164_pol_16_val_64_64_N2048_B512_lr2e-04_GNNValue_0_EmbOpt_shared_mode_action_per_controller_Epochs_10_Nenvs_8_GRU_AntBulletEnv-v0_09-03_18-00-53')
    p.add_argument("--total_timesteps",
                   help="The total number of samples (env steps) to train on",
                   type=int,
                   default=4000000)
    p.add_argument('--tensorboard_log',
                   help='the log location for tensorboard (if None, no logging)',
                   default="runs")
    p.add_argument('--n_envs',
                   help="Number of environments to run in parallel to collect rollouts. Each environment requires one CPU",
                   type=int,
                   default=8)
    p.add_argument('--device',
                   help='Device (cpu, cuda, ...) on which the code should be run. '
                        'Setting it to auto, the code will be run on the GPU if possible.',
                   default="auto")
    p.add_argument('--experiment_name',
                   help='name to append to the tensorboard logs directory',
                   default=None)
    p.add_argument('--experiment_name_suffix',
                   help='name to append to the tensorboard logs directory',
                   default=None)
    p.add_argument('--model_name',
                   help='The name of your saved model',
                   default='model.zip')
    args = p.parse_args()
    return args
if __name__ == '__main__':
    train(parse_arguments())
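# Example invocation (the run directory below is illustrative; point
# --pretrained_output at an existing run that contains train_arguments.yaml
# and the saved model):
#   python training_continue.py \
#       --pretrained_output runs/<pretrained_run_dir> \
#       --total_timesteps 2000000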