forked from DanialNejad/CustomMuJoCoEnviromentForRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test.py
25 lines (22 loc) · 730 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Roll out a trained SAC policy in BallBalanceEnv and visualise it.

Loads the policy saved at ``MODEL_PATH`` and steps the environment for
``NUM_STEPS`` steps using deterministic actions. Every rendered frame is shown
in an OpenCV window. Every ``FRAME_STRIDE``-th frame is kept so the rollout can
optionally be written to a GIF at ``GIF_PATH`` (set ``SAVE_GIF = True``).
"""
from pathlib import Path

import cv2
import imageio
from stable_baselines3 import SAC

from ball_balance_env import BallBalanceEnv

MODEL_PATH = "sac_ball_balance.zip"
NUM_STEPS = 500
FRAME_STRIDE = 5  # keep one frame out of every FRAME_STRIDE steps for the GIF
SAVE_GIF = False  # set to True to write the collected frames to GIF_PATH
GIF_PATH = Path("media/test1.gif")
WINDOW_NAME = "image"


def main():
    """Run the rollout, display it live, and optionally save a GIF.

    The environment is reset whenever an episode terminates or is truncated,
    so exactly ``NUM_STEPS`` steps are taken in total. The environment and
    the OpenCV window are released even if loading or stepping fails.
    """
    env = BallBalanceEnv(render_mode="rgb_array")
    frames = []
    try:
        model = SAC.load(MODEL_PATH)
        obs, info = env.reset()
        for step in range(NUM_STEPS):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            image = env.render()
            if step % FRAME_STRIDE == 0:
                frames.append(image)
            # "rgb_array" frames are RGB (Gymnasium convention; assumed for
            # BallBalanceEnv, TODO confirm), but OpenCV displays BGR. Convert
            # only the displayed copy so the frames kept for the GIF stay RGB.
            cv2.imshow(WINDOW_NAME, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)
            if terminated or truncated:
                obs, info = env.reset()
    finally:
        env.close()
        cv2.destroyAllWindows()

    if SAVE_GIF and frames:
        # The output directory may not exist yet in a fresh checkout.
        GIF_PATH.parent.mkdir(parents=True, exist_ok=True)
        with imageio.get_writer(str(GIF_PATH), mode="I") as writer:
            for frame in frames:
                writer.append_data(frame)


if __name__ == "__main__":
    main()