-
Notifications
You must be signed in to change notification settings - Fork 1
/
cartpole_simscape.py
131 lines (109 loc) · 4.3 KB
/
cartpole_simscape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from simulink_gym import SimulinkEnv, Observation, Observations
from gym.spaces import Discrete
from pathlib import Path
import numpy as np
import math
# Define example environment:
class CartPoleSimscape(SimulinkEnv):
    """Classic Cart Pole control environment implemented in Matlab/Simulink/Simscape.

    Observation:
        Type: Box(4)
        Num   Observation             Min                   Max
        0     Cart Position           -4.8                  4.8
        1     Cart Velocity           -Inf                  Inf
        2     Pole Angle              -0.418 rad (-24 deg)  0.418 rad (24 deg)
        3     Pole Angular Velocity   -Inf                  Inf

    Actions:
        Type: Discrete(2)
        Num   Action
        0     Push cart to the left
        1     Push cart to the right

        Note: The amount the velocity is changed is not fixed; it depends on
        the angle the pole is pointing, because the center of gravity of the
        pole changes the amount of energy needed to move the cart underneath it.

    Reward:
        Reward is 1 for every step taken, including the termination step.
    """

    def __init__(
        self,
        stop_time: float = 10.0,
        step_size: float = 0.02,
        model_debug: bool = False,
    ):
        """Simscape implementation of the classic Cart Pole environment.

        Parameters:
            stop_time: float, default 10
                maximum simulation duration in seconds
            step_size: float, default 0.02
                size of simulation step in seconds
            model_debug: bool, default False
                Flag for setting up the model debug mode (see Readme.md for details)
        """
        # The Simulink model file sits next to this Python module:
        model_file = (
            Path(__file__).parent.absolute().joinpath("cartpole_simscape.slx")
        )
        super().__init__(model_path=model_file, model_debug=model_debug)

        # Two discrete actions: 0 pushes the cart left, 1 pushes it right.
        self.action_space = Discrete(2)

        # Episode-termination thresholds; the observation-space bounds below
        # are deliberately twice as wide as the termination limits.
        self.max_cart_position = 2.4
        self.max_pole_angle_rad = 12 * math.pi / 180.0
        position_bound = self.max_cart_position * 2.0
        angle_bound = self.max_pole_angle_rad * 2.0

        # Each observation maps to a workspace variable holding its initial value:
        self.observations = Observations(
            [
                Observation(
                    "pos",
                    -position_bound,
                    position_bound,
                    "x_0",
                    self.set_workspace_variable,
                ),
                Observation("vel", -np.inf, np.inf, "v_0", self.set_workspace_variable),
                Observation(
                    "theta",
                    -angle_bound,
                    angle_bound,
                    "theta_0",
                    self.set_workspace_variable,
                ),
                Observation(
                    "omega", -np.inf, np.inf, "omega_0", self.set_workspace_variable
                ),
            ]
        )

        # Seed the environment state from the observations' initial values:
        self.state = self.observations.initial_state

        # Push the simulation timing parameters into the model:
        self.set_model_parameter("StopTime", stop_time)
        self.set_workspace_variable("step_size", step_size)

    def reset(self):
        """Resample the initial state uniformly and restart the simulation."""
        # All four state entries drawn independently from [-0.05, 0.05):
        self.observations.initial_state = np.random.uniform(
            low=-0.05, high=0.05, size=(4,)
        )
        # Delegate the actual restart to the common base-class reset:
        super()._reset()
        # Return reshaped state. Needed for use as tf.model input:
        return self.state

    def step(self, action):
        """Method for stepping the simulation."""
        state, simulation_time, terminated, truncated = self.sim_step(int(action))
        # The episode ends when the simulation stops or the cart/pole leave
        # their allowed ranges (abs-bound check is equivalent to the two-sided one):
        out_of_bounds = (
            abs(state[0]) > self.max_cart_position
            or abs(state[2]) > self.max_pole_angle_rad
        )
        done = bool(terminated or truncated or out_of_bounds)
        # Constant reward for every step inside the state and time limits,
        # including the terminating step:
        reward = 1
        info = {"simulation time [s]": simulation_time}
        return state, reward, done, info