Fix Gym Env and Implement RL Training #203

Open · wants to merge 17 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -169,3 +169,5 @@ cython_debug/
 #.idea/
 
 building_parameters/
+
+runs/
4 changes: 1 addition & 3 deletions src/neuromancer/psl/building_envelope.py
@@ -167,7 +167,7 @@ def equations(self, x, u, d):
         y = self.C @ x + F
         return x, y
 
-    def get_simulation_args(self, nsim, x0, U, D):
+    def get_simulation_args(self, nsim=None, x0=None, U=None, D=None):
         nsim = self.nsim if nsim is None else nsim
         x0 = self.get_x0() if x0 is None else x0
         D = self.get_D(nsim+1) if D is None else D
@@ -235,8 +235,6 @@ def get_q(self, u):
         print(n)
         s = system(backend='torch')
         out = s.simulate(nsim=5)
-
-
         print({k: v.shape for k, v in out.items()})
 
     for n, system in systems.items():
108 changes: 78 additions & 30 deletions src/neuromancer/psl/gym.py
@@ -1,46 +1,94 @@
-from scipy.io import loadmat
-from gym import spaces, Env
-
 import numpy as np
-from neuromancer.psl.nonautonomous import systems, ODE_NonAutonomous
+import torch
+from gymnasium import spaces, Env
+from gymnasium.envs.registration import register
+from neuromancer.utils import seed_everything
+from neuromancer.psl.building_envelope import BuildingEnvelope, systems
 
-def disturbance(file='../../TimeSeries/disturb.mat', n_sim=8064):
-    return loadmat(file)['D'][:, :n_sim].T # n_sim X 3
-
-class GymWrapper(Env):
-    """Custom Environment that follows gym interface"""
-    metadata = {'render.modes': ['human']}
+class BuildingEnv(Env):
+    """Custom Gym Environment for simulating building energy systems.
+
+    This environment models the dynamics of a building's thermal system,
+    allowing for control actions to be taken and observing the resulting
+    thermal comfort levels. The environment adheres to the OpenAI Gym
+    interface, providing methods for stepping through the simulation,
+    resetting the state, and rendering the environment.
+
+    Attributes:
+        metadata (dict): Information about the rendering modes available.
+        ymin (float): Minimum threshold for thermal comfort.
+        ymax (float): Maximum threshold for thermal comfort.
+    """
 
-    def __init__(self, simulator, U=None, ninit=None, nsim=None, ts=None, x0=None,
-                 perturb=[lambda: 0. , lambda: 1.]):
+    def __init__(self, simulator, seed=None, fully_observable=False,
+                 ymin=20.0, ymax=22.0, backend='numpy'):
         super().__init__()
-        if isinstance(simulator, ODE_NonAutonomous):
-            self.simulator = simulator
+        if isinstance(simulator, BuildingEnvelope):
+            self.model = simulator
         else:
-            self.simulator = systems[simulator](U=U, ninit=ninit, nsim=nsim, ts=ts, x0=x0, norm_func=norm_func)
-        self.action_space = spaces.Box(-np.inf, np.inf, shape=self.simulator.get_U().shape[-1], dtype=np.float32)
-        self.observation_space = spaces.Box(-np.inf, np.inf, shape=self.simulator.x0.shape,dtype=np.float32)
-        self.perturb = perturb
+            self.model = systems[simulator](seed=seed, backend=backend)
+        self.fully_observable = fully_observable
+        self.ymin = ymin
+        self.ymax = ymax
+        obs, _ = self.reset(seed=seed)
+        self.action_space = spaces.Box(
+            self.model.umin, self.model.umax, shape=self.model.umin.shape, dtype=np.float32)
+        self.observation_space = spaces.Box(
+            -np.inf, np.inf, shape=[len(obs)], dtype=np.float32)
 
     def step(self, action):
-        self.x = self.A*np.asmatrix(self.x).reshape(4, 1) + self.B*action.T + self.E*(self.D[self.tstep].reshape(3,1))
-        self.y = (self.C * np.asmatrix(self.x)).flatten()
-        self.tstep += 1
-        observation = (self.y, self.x)[self.fully_observable].astype(np.float32)
-        self.X_out = np.concatenate([self.X_out, np.array(self.x.reshape([1, 4]))])
-        return np.array(observation).flatten(), self.reward(), self.tstep == self.X.shape[0], {'xout': self.X_out}
+        u = np.asarray(action)
+        self.d = self.get_disturbance()
+        # expect the model to accept both 1D arrays and 2D arrays
+        self.x, self.y = self.model(self.x, u, self.d)
+        self.t += 1
+        self.X_rec = np.append(self.X_rec, self.x)
+        obs = self.get_obs()
+        reward = self.get_reward(u, self.y)
+        done = self.t == self.model.nsim
+        truncated = False
+        return obs, reward, done, truncated, dict(X_rec=self.X_rec)
+
+    def get_reward(self, u, y, ymin=20.0, ymax=22.0):
+        # energy minimization
+        # u[0] is the nominal mass flow rate, u[1] is the temperature difference
+        q = self.model.get_q(u).sum()  # q is the heat flow in W
+        k = np.sum(u != 0.0)  # number of actions
+        action_loss = 0.01 * q + 0.01 * k
+
+        # thermal comfort
+        comfort_reward = 5. * np.sum((ymin < y) & (y < ymax))  # y in °C
+
+        return comfort_reward - action_loss
+
+    def get_disturbance(self):
+        return self.model.get_D(1).flatten()
+
+    def get_obs(self):
+        obs_mask = torch.as_tensor(self.model.C.flatten(), dtype=torch.bool)
+        self.y = self.x[obs_mask]
+        d = self.d if self.fully_observable else self.d[self.model.d_idx]
+        obs = self.x if self.fully_observable else self.y
+        obs = np.hstack([obs, self.ymin, self.ymax, d])
+        return obs.astype(np.float32)
 
-    def reset(self, dset='train'):
-        self.tstep = 0
-        observation = (self.y, self.x)[self.fully_observable].astype(np.float32)
-        self.X_out = np.empty(shape=[0, 4])
-        return np.array(observation).flatten()
+    def reset(self, seed=None, options=None):
+        seed_everything(seed)
+        self.t = 0
+        self.x = self.model.x0
+        self.d = self.get_disturbance()
+        self.X_rec = np.empty(shape=[0, 4])
+        return self.get_obs(), dict(X_rec=self.X_rec)
 
     def render(self, mode='human'):
-        print('render')
+        pass
 
-systems = {k: GymWrapper for k in GymWrapper.envs}
+
+# allow the custom envs to be directly instantiated by gym.make(env_id)
+for env_id in systems:
+    register(
+        env_id,
+        entry_point='neuromancer.psl.gym:BuildingEnv',
+        kwargs=dict(simulator=env_id),
+    )
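
For reviewers, here is a minimal usage sketch of the registered environments, assuming this branch is installed. The env id is looked up from the `systems` dict rather than hard-coded, since the exact keys live in `neuromancer.psl.building_envelope`:

```python
# Usage sketch (not part of the diff). Importing neuromancer.psl.gym runs the
# register() loop above, so every building model id becomes a valid gym id.
import gymnasium as gym
import neuromancer.psl.gym  # noqa: F401  (triggers the register() calls)
from neuromancer.psl.building_envelope import systems

env_id = sorted(systems)[0]        # pick any registered building model id
env = gym.make(env_id)             # entry point: neuromancer.psl.gym:BuildingEnv

obs, info = env.reset(seed=0)
episode_return = 0.0
for _ in range(50):
    action = env.action_space.sample()   # random controller as a crude baseline
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
    if terminated or truncated:
        break
print(f"return over episode fragment: {episode_return:.2f}")
```
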
85 changes: 85 additions & 0 deletions src/neuromancer/rl/README.md
@@ -0,0 +1,85 @@
### **Project Scheme: Hybrid Control System with Differentiable Predictive Control (DPC) and Deep Reinforcement Learning (DRL)**

---

### **Objective**:
The goal is to develop a hybrid control system that combines **Differentiable Predictive Control (DPC)** and **Deep Reinforcement Learning (DRL)** for efficient, robust control of a complex physical system. The system integrates a model based on **Ordinary Differential Equations (ODEs)** with **Neural State Space Models (NSSMs)** that augment the inputs of the control policy, and uses actor-critic DRL to optimize the long-term strategy.

---

### **Components**:

1. **Physical System (Ground Truth)**:
- A real-world system with **limited access**, due to high cost, complexity, or experimental constraints.

2. **System Model**:
- A **system model** based on **ODEs** or **Stochastic Differential Equations (SDEs)** to capture uncertainties and perturbations. This serves as the predictive model for **DPC**.
- The model may include **neural network (NN) terms**, such as in **Universal Differential Equations (UDEs)**, trained using real-world data when available.
- **NSSMs** are used to model the system dynamics and provide future state predictions to augment the inputs to control models.

3. **Loss Function**:
- The objective function representing system performance (e.g., tracking error, energy consumption). This drives DPC optimization and defines the DRL reward.

4. **Policy Model (Actor Network)**:
- An NN-based **control policy** that outputs actions. First trained via **DPC**, and later improved using **DRL** (e.g., PPO or SAC).
- The policy network receives **current states** and **NSSM-predicted future states** as inputs to enable foresight in decision-making.

5. **Value Model (Critic Network)**:
   - A **critic network** used in DRL to estimate long-term returns. It also receives **augmented inputs** from current states and NSSM predictions (see the sketch below).
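
A minimal sketch of what this input augmentation can look like for the two networks above; PyTorch, with placeholder layer sizes, and not part of this PR:

```python
# Hypothetical actor/critic whose observations are augmented with NSSM-predicted
# future states. Sizes (nx, nx_pred, nu, hidden width 64) are placeholders.
import torch
import torch.nn as nn

class AugmentedActor(nn.Module):
    def __init__(self, nx, nx_pred, nu):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nx + nx_pred, 64), nn.ReLU(),
            nn.Linear(64, nu), nn.Tanh(),          # actions squashed to [-1, 1]
        )

    def forward(self, x, x_pred):
        return self.net(torch.cat([x, x_pred], dim=-1))

class AugmentedCritic(nn.Module):
    def __init__(self, nx, nx_pred):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nx + nx_pred, 64), nn.ReLU(),
            nn.Linear(64, 1),                      # scalar value estimate
        )

    def forward(self, x, x_pred):
        return self.net(torch.cat([x, x_pred], dim=-1))
```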

---

### **Workflow**:

#### **1. Model the Physical System Using ODE**:
- **System Model**: Model the physical system's dynamics with **ODEs** (optionally incorporating stochastic elements to capture uncertainties). This serves as the **system model** for short-term control in DPC.
- **NN Components**: If necessary, use real-world data to train any **neural network terms** in the system model.
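
As a concrete illustration (an assumption, not the actual PSL model), the system model can be as simple as a discretized linear envelope with heat-input and disturbance channels; the matrices below are random placeholders standing in for identified building parameters:

```python
# Toy discrete-time building envelope: x_{t+1} = A x_t + B u_t + E d_t, y_t = C x_t.
import numpy as np

rng = np.random.default_rng(0)
nx, nu, nd = 4, 2, 3
A = np.eye(nx) + 0.01 * rng.standard_normal((nx, nx))   # slow thermal dynamics
B = 0.01 * rng.standard_normal((nx, nu))                 # control (heat flow) input
E = 0.01 * rng.standard_normal((nx, nd))                 # weather / occupancy disturbances
C = np.array([[0., 0., 0., 1.]])                         # measured zone temperature

def model_step(x, u, d):
    """One step of the surrogate system model (already discretized)."""
    x_next = A @ x + B @ u + E @ d
    y = C @ x_next
    return x_next, y
```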

---

#### **2. Gather Real-World and Simulated Data**:
- **Data Collection**: Gather real-world data from the physical system and augment it with simulated data from the ODE-based system model.
- **Dataset**: Combine both real and simulated data into a dataset for NSSM and DPC training.

---

#### **3. Train the Neural State Space Model (NSSM)**:
- **NSSM Training**: Train the **NSSM** using the collected dataset. The NSSM learns to predict future states of the system from current states and control inputs.
- **Input Augmentation**: Use NSSM-predicted next states to augment the inputs to the **policy model** (in DPC and DRL) and the **value model** (in DRL).
- This enables proactive decision-making by incorporating future state predictions into control actions.
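
A compact sketch of NSSM training and the resulting input augmentation, with dummy tensors standing in for the collected dataset:

```python
# One-step NSSM: predict x_{t+1} from (x_t, u_t, d_t); trained with an MSE loss.
import torch
import torch.nn as nn

nx, nu, nd = 4, 2, 3
nssm = nn.Sequential(nn.Linear(nx + nu + nd, 64), nn.GELU(), nn.Linear(64, nx))
opt = torch.optim.Adam(nssm.parameters(), lr=1e-3)

# Dummy transition dataset (replace with real + simulated rollouts).
X, U, D, X_next = (torch.randn(1024, n) for n in (nx, nu, nd, nx))

for epoch in range(100):
    pred = nssm(torch.cat([X, U, D], dim=-1))     # predicted next states
    loss = nn.functional.mse_loss(pred, X_next)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Augmented observation for the policy/value networks: current state + prediction.
x_aug = torch.cat([X, nssm(torch.cat([X, U, D], dim=-1)).detach()], dim=-1)
```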

---

#### **4. Pre-train the Policy Network Using Differentiable Predictive Control (DPC)**:
- **DPC Training**: Pre-train the policy network with **DPC**, optimizing the control actions over a finite horizon using the **system model** (based on ODEs).
- **NSSM Predictions**: Augment the policy network's inputs with **NSSM-predicted future states** to improve decision-making.
- **Respect Constraints**: Ensure that the DPC respects system constraints, such as safety limits or actuator boundaries.
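
A simplified sketch of the finite-horizon DPC objective, with constraint violations handled as soft penalties; `policy`, `model`, and the weights are placeholders for differentiable PyTorch callables and tuned coefficients:

```python
import torch

def dpc_loss(policy, model, x0, d_seq, ymin=20.0, ymax=22.0, horizon=32):
    """Finite-horizon rollout loss: comfort-band violations plus an energy term."""
    x, loss = x0, torch.zeros(())
    for t in range(horizon):
        u = policy(x)                        # differentiable control action
        x, y = model(x, u, d_seq[t])         # differentiable system model step
        comfort_violation = torch.relu(ymin - y) + torch.relu(y - ymax)
        energy = u.abs().sum()
        loss = loss + comfort_violation.sum() + 0.01 * energy
    return loss / horizon

# Pre-training loop (sketch): only the policy parameters are updated.
# opt = torch.optim.AdamW(policy.parameters(), lr=1e-3)
# for x0, d_seq in sampled_scenarios:
#     opt.zero_grad(); dpc_loss(policy, model, x0, d_seq).backward(); opt.step()
```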

---

#### **5. Train Policy Network Using DRL**:
- **Policy Initialization**: Initialize the **actor network** (policy) using the DPC-trained policy for a strong starting point.
- **Stochastic Exploration**: Ensure the policy includes some stochasticity to allow for exploration beyond the DPC-optimized policy.
- **DRL Optimization**: Refine the policy using DRL methods like **PPO** or **SAC** to maximize long-term performance.
- **Reward Function**:
- Define the reward as the **difference in losses** between the DPC and DRL policies:
\[
R = \mathcal{L}_{\text{DPC}} - \mathcal{L}_{\text{DRL}}
\]
- This encourages the RL agent to improve over the DPC baseline policy.
- **Critic Network**: Randomly initialize the **critic network**, which will be trained alongside the policy during DRL.
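
One simple way to realize this reward (an illustrative assumption, not the only option) is a Gymnasium wrapper that compares, at each step, the stage loss of the learned action against the stage loss the frozen DPC baseline would incur for the same observation; `loss_fn` and `dpc_policy` are placeholders:

```python
import gymnasium as gym

class RelativeRewardWrapper(gym.Wrapper):
    """Reward = stage loss of the DPC baseline minus stage loss of the learned policy."""

    def __init__(self, env, dpc_policy, loss_fn):
        super().__init__(env)
        self.dpc_policy = dpc_policy   # frozen, pre-trained DPC policy
        self.loss_fn = loss_fn         # maps (observation, action) -> scalar stage loss

    def step(self, action):
        obs, _, terminated, truncated, info = self.env.step(action)
        loss_drl = self.loss_fn(obs, action)                 # loss paid by the RL action
        loss_dpc = self.loss_fn(obs, self.dpc_policy(obs))   # loss the DPC action would pay
        reward = loss_dpc - loss_drl                         # > 0 when RL beats the baseline
        return obs, reward, terminated, truncated, info
```

The wrapped environment can then be handed to an off-the-shelf PPO or SAC implementation, with the actor initialized from the DPC-trained weights.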

---

### **Final Summary**:

1. **Model the Physical System**: Use ODEs (with stochastic elements if necessary) to represent system dynamics.
2. **Gather Data**: Collect real-world and simulated data for model training and policy optimization.
3. **Train NSSM**: Train the NSSM to predict future states, augmenting inputs to the control models.
4. **Pre-train Policy with DPC**: Use DPC to pre-train the policy using the system model.
5. **Train Policy with DRL**: Refine the policy using DRL (PPO or SAC), optimizing with the reward defined as the loss difference between DPC and DRL policies.

---

### **Outcome**:
This hybrid framework combines the short-term, constraint-aware optimization of **DPC** with the long-term adaptability of **DRL**. By augmenting inputs with **NSSM-predicted future states**, the system gains foresight, allowing for more proactive, robust control strategies. The reward structure, comparing DPC and DRL policy performance, ensures continual improvement over the baseline.
Empty file added src/neuromancer/rl/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions src/neuromancer/rl/diagram.dot
@@ -0,0 +1,31 @@
digraph Hybrid_Control_System {
rankdir=LR; // Left to Right layout
node [shape=box, style=rounded];

// Components
physical_system [label="Physical System"];
system_model [label="System Model (ODE/SDE)"];
data [label="Time Series Dataset"];
nssm [label="Neural State Space Model (NSSM)"];
policy_model [label="Policy Model (Actor)"];
value_model [label="Value Model (Critic)"];
loss [label="System Loss"];
reward [label="RL Reward"];

// Workflow connections
physical_system -> data [label="Generate Data"];
physical_system -> system_model [label="Modelling"];
system_model -> data [label="Simulate Data"];
system_model -> loss [label="Loss Function"];
policy_model -> system_model [label="Decision Making"];
data -> nssm [label="NSSM Training Data"];
data -> system_model [label="DPC Training Data"];
nssm -> policy_model [label="Augment Inputs"];
nssm -> value_model [label="Augment Inputs"];
loss -> reward[label="Loss(DPC) - Loss(DRL)"];
loss -> policy_model[label="Optimize by DPC"];
reward -> value_model [label="Cumulative Return"];
system_model -> value_model [label="Observation"];
system_model -> policy_model [label="Observation"];
value_model -> policy_model [label="Optimize by DRL"];
}