# inference_with_gym.py
# (Scraper residue removed: GitHub notification banner and rendered line-number
#  gutter from the web view — not part of the source file.)
import gym
import mlagents_envs
# from baselines import deepq
# from baselines import logger
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from gym_trainer.a2c import MyA2C
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from torch.utils.tensorboard import SummaryWriter
import sys
import torch
import numpy as np
import random
from pathlib import Path
def main():
    """Run deterministic inference with a trained MyA2C policy in a Unity env.

    Connects to a Unity environment (no path given, so it waits for a running
    Unity editor instance), seeds all RNG sources for reproducibility, loads a
    policy-network checkpoint, rolls out 20 deterministic episodes, and prints
    episode-length statistics.
    """
    # unity_env = UnityEnvironment("/Users/rishimalhotra/projects/checker2_one_agent.app")
    unity_env = UnityEnvironment()

    # Seed every RNG source the rollout may touch so runs are reproducible.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    # uint8_visual keeps visual observations as uint8 (expected by the policy).
    vec_env = UnityToGymWrapper(unity_env, uint8_visual=True)
    try:
        # Earlier checkpoints, kept for reference (both averaged ~43 steps):
        # "good_checkpoints/using_log_exp_with_lower_lr_for_sigma/step_18003/policy_network.pth"
        # "good_checkpoints/increasing_n_steps_to_15/step_25000/policy_network.pth"
        load_policy_network_checkpoint_path = (
            "good_checkpoints/multiple_rollouts_per_batch/step_49722/policy_network.pth"
        )
        model = MyA2C(
            vec_env,
            None,
            n_steps=10,
            load_policy_network_checkpoint_path=load_policy_network_checkpoint_path,
            num_rollouts_per_update=32,
        )
        average_episode_length, std_episode_length = model.collect_rollouts_for_inference(
            num_episodes=20, deterministic=True
        )
        print('average episode length:', average_episode_length)
        # Fix: the std was computed but never reported.
        print('std episode length:', std_episode_length)
        # model.save("a2c_cartpole")
    finally:
        # Fix: close the Unity connection even if the rollout raises, otherwise
        # the editor process is left hanging on the open socket.
        vec_env.close()
# Script entry point: only run inference when executed directly, not on import.
if __name__ == '__main__':
    main()