From ee8284a3d1858cf0aa1f5d830eaf87dbcf9019ae Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Tue, 22 Oct 2024 18:40:37 +0100 Subject: [PATCH 01/17] update gym env to make it work --- src/neuromancer/psl/building_envelope.py | 4 +- src/neuromancer/psl/gym.py | 79 +++++++++++++++--------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/src/neuromancer/psl/building_envelope.py b/src/neuromancer/psl/building_envelope.py index 4194798d..6017ce86 100644 --- a/src/neuromancer/psl/building_envelope.py +++ b/src/neuromancer/psl/building_envelope.py @@ -167,7 +167,7 @@ def equations(self, x, u, d): y = self.C @ x + F return x, y - def get_simulation_args(self, nsim, x0, U, D): + def get_simulation_args(self, nsim=None, x0=None, U=None, D=None): nsim = self.nsim if nsim is None else nsim x0 = self.get_x0() if x0 is None else x0 D = self.get_D(nsim+1) if D is None else D @@ -235,8 +235,6 @@ def get_q(self, u): print(n) s = system(backend='torch') out = s.simulate(nsim=5) - - print({k: v.shape for k, v in out.items()}) for n, system in systems.items(): diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index f821d7da..93c6ef98 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -2,45 +2,66 @@ from gym import spaces, Env import numpy as np -from neuromancer.psl.nonautonomous import systems, ODE_NonAutonomous +from neuromancer.psl.building_envelope import BuildingEnvelope, systems -def disturbance(file='../../TimeSeries/disturb.mat', n_sim=8064): - return loadmat(file)['D'][:, :n_sim].T # n_sim X 3 +class BuildingEnv(Env): + """Custom Gym Environment for simulating building energy systems. -class GymWrapper(Env): - """Custom Environment that follows gym interface""" - metadata = {'render.modes': ['human']} + This environment models the dynamics of a building's thermal system, + allowing for control actions to be taken and observing the resulting + thermal comfort levels. The environment adheres to the OpenAI Gym + interface, providing methods for stepping through the simulation, + resetting the state, and rendering the environment. - def __init__(self, simulator, U=None, ninit=None, nsim=None, ts=None, x0=None, - perturb=[lambda: 0. , lambda: 1.]): + Attributes: + metadata (dict): Information about the rendering modes available. + ymin (float): Minimum threshold for thermal comfort. + ymax (float): Maximum threshold for thermal comfort. 
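+
+    Example (illustrative sketch; assumes the 'SimpleSingleZone' PSL system):
+        >>> from neuromancer.psl.gym import BuildingEnv
+        >>> env = BuildingEnv('SimpleSingleZone')
+        >>> obs = env.reset()
+        >>> obs, reward, done, info = env.step(env.action_space.sample())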
+ """ + def __init__(self, simulator, seed=None, fully_observable=True): super().__init__() - if isinstance(simulator, ODE_NonAutonomous): - self.simulator = simulator + if isinstance(simulator, BuildingEnvelope): + self.sys = simulator else: - self.simulator = systems[simulator](U=U, ninit=ninit, nsim=nsim, ts=ts, x0=x0, norm_func=norm_func) - self.action_space = spaces.Box(-np.inf, np.inf, shape=self.simulator.get_U().shape[-1], dtype=np.float32) - self.observation_space = spaces.Box(-np.inf, np.inf, shape=self.simulator.x0.shape,dtype=np.float32) - self.perturb = perturb + self.sys = systems[simulator](seed=seed) + self.action_space = spaces.Box( + self.sys.umin, self.sys.umax, shape=self.sys.umin.shape, dtype=np.float32) + self.observation_space = spaces.Box( + -np.inf, np.inf, shape=self.sys.x0.shape, dtype=np.float32) + self.fully_observable = fully_observable + self.reset() def step(self, action): - self.x = self.A*np.asmatrix(self.x).reshape(4, 1) + self.B*action.T + self.E*(self.D[self.tstep].reshape(3,1)) - self.y = (self.C * np.asmatrix(self.x)).flatten() - self.tstep += 1 - observation = (self.y, self.x)[self.fully_observable].astype(np.float32) - self.X_out = np.concatenate([self.X_out, np.array(self.x.reshape([1, 4]))]) - return np.array(observation).flatten(), self.reward(), self.tstep == self.X.shape[0], {'xout': self.X_out} + u = np.asarray(action) + d = self.sys.get_D(1).flatten() + self.x, self.y = self.sys.equations(self.x, u, d) + self.t += 1 + self.X_rec = np.append(self.X_rec, self.x) + reward = self.reward(u, self.y) + done = self.t == self.sys.nsim + return self.obs, reward, done, dict(X_rec=self.X_rec) + + def reward(self, u, y, ymin=21.0, ymax=23.0): + # energy minimization + action_loss = 0.1 * np.sum(u > 0.0) - def reset(self, dset='train'): + # thermal comfort + inbound_reward = 5. 
* np.sum((ymin < y) & (y < ymax)) - self.tstep = 0 - observation = (self.y, self.x)[self.fully_observable].astype(np.float32) - self.X_out = np.empty(shape=[0, 4]) - return np.array(observation).flatten() + return inbound_reward - action_loss + + @property + def obs(self): + return (self.y, self.x)[self.fully_observable].astype(np.float32) - def render(self, mode='human'): - print('render') - -systems = {k: GymWrapper for k in GymWrapper.envs} + def reset(self): + self.t = 0 + self.x = self.sys.x0 + self.y = self.sys.C * self.x + self.X_rec = np.empty(shape=[0, 4]) + return self.obs + def render(self, mode='human'): + pass From 19bc8dc2ffb33de16fdb07733b83e736b61b4aae Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Tue, 22 Oct 2024 22:11:44 +0100 Subject: [PATCH 02/17] update gym env, todo, requirements --- .gitignore | 3 +++ src/neuromancer/psl/gym.py | 19 ++++++++++--------- src/neuromancer/rl/TODO.md | 20 ++++++++++++++++++++ src/neuromancer/rl/requirements.txt | 2 ++ 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 src/neuromancer/rl/TODO.md create mode 100644 src/neuromancer/rl/requirements.txt diff --git a/.gitignore b/.gitignore index e10aa4e8..f0d1b9da 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,6 @@ cython_debug/ #.idea/ building_parameters/ + +*.dot +*.png \ No newline at end of file diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index 93c6ef98..f2beef20 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -23,24 +23,25 @@ class BuildingEnv(Env): def __init__(self, simulator, seed=None, fully_observable=True): super().__init__() if isinstance(simulator, BuildingEnvelope): - self.sys = simulator + self.model = simulator else: - self.sys = systems[simulator](seed=seed) + self.model = systems[simulator](seed=seed) self.action_space = spaces.Box( - self.sys.umin, self.sys.umax, shape=self.sys.umin.shape, dtype=np.float32) + self.model.umin, self.model.umax, shape=self.model.umin.shape, dtype=np.float32) self.observation_space = spaces.Box( - -np.inf, np.inf, shape=self.sys.x0.shape, dtype=np.float32) + -np.inf, np.inf, shape=self.model.x0.shape, dtype=np.float32) self.fully_observable = fully_observable self.reset() def step(self, action): u = np.asarray(action) - d = self.sys.get_D(1).flatten() - self.x, self.y = self.sys.equations(self.x, u, d) + d = self.model.get_D(1).flatten() + # model should accept either 1D arrays or 2D (n, 1) arrays + self.x, self.y = self.model(self.x, u, d) self.t += 1 self.X_rec = np.append(self.X_rec, self.x) reward = self.reward(u, self.y) - done = self.t == self.sys.nsim + done = self.t == self.model.nsim return self.obs, reward, done, dict(X_rec=self.X_rec) def reward(self, u, y, ymin=21.0, ymax=23.0): @@ -58,8 +59,8 @@ def obs(self): def reset(self): self.t = 0 - self.x = self.sys.x0 - self.y = self.sys.C * self.x + self.x = self.model.x0 + self.y = self.model.C * self.x self.X_rec = np.empty(shape=[0, 4]) return self.obs diff --git a/src/neuromancer/rl/TODO.md b/src/neuromancer/rl/TODO.md new file mode 100644 index 00000000..1236b368 --- /dev/null +++ b/src/neuromancer/rl/TODO.md @@ -0,0 +1,20 @@ +- DPC as a Safety Layer for PPO + In this approach, PPO is the primary control policy that interacts with the environment to maximize long-term rewards. DPC acts as a safety layer that monitors PPO’s control actions and corrects them if they violate constraints or if the system is predicted to become unstable. 
+  How it works:
+  PPO generates control actions based on its learned policy, focusing on maximizing long-term rewards through exploration and learning.
+  DPC monitors PPO’s actions in real time. Using a neural network, DPC predicts the future states of the system over a short horizon. If PPO’s proposed action leads to unsafe or suboptimal behavior (e.g., violating constraints or causing instability), DPC overrides PPO’s action with a safer one.
+  Fallback mechanism: If PPO’s action is safe and within the acceptable range, it is used; otherwise, DPC’s optimal control action is applied instead (a minimal sketch of this override follows the list below).
+- DPC for Real-Time Control, PPO for Long-Term Policy Learning
+  In this approach, DPC is used for immediate predictive control, ensuring that the system adheres to constraints and is optimized based on immediate feedback. PPO is responsible for learning the long-term control policy, helping the system adapt to changes and improve its performance over time. The control policy learned by PPO can guide or enhance the decisions made by DPC.
+  How it works:
+  DPC handles short-term optimization: At each time step, DPC uses a neural network to predict the system's future states over a short horizon and computes the optimal control action that minimizes a cost function while respecting constraints.
+  PPO updates the long-term policy: Over time, PPO learns a control policy that maximizes cumulative rewards by interacting with the environment. PPO can provide feedback to DPC in the form of improved control actions or policy adjustments.
+  Policy blending: The control actions from PPO and DPC can also be blended with a weighting factor.
+- PPO for Model Learning in DPC
+  In this approach, PPO is used to improve the neural network model used in DPC. While DPC typically relies on a pre-trained neural network to predict future states, PPO can continuously update and refine this model based on its interactions with the environment.
+  How it works:
+  DPC predicts short-term states using a neural network and computes optimal control actions based on these predictions.
+  PPO refines the neural network model: As PPO interacts with the environment, it improves its understanding of the system’s dynamics. PPO can then update the neural network used by DPC, making the predictions more accurate and improving DPC’s control performance.
+  Online learning: PPO continuously learns from real-time data, allowing DPC to adapt to changing system dynamics, external disturbances, or shifts in the environment.
+  - [ ] Train DPC first, then train the critic network under the DPC policy, and use the DPC policy to initialize the PPO policy.
+  - [ ] Use NSSM to predict the next model state and add it as extra input to the policy model.
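A minimal sketch of the safety-layer override from the first approach and the weighted blending from the second, assuming hypothetical `ppo_policy`, `dpc_policy`, and `rollout` callables (these are not existing neuromancer APIs, and the comfort bounds are illustrative):

```python
import numpy as np

def safe_action(obs, ppo_policy, dpc_policy, rollout, ymin=21.0, ymax=23.0):
    """Return the PPO action unless the model rollout predicts a comfort violation."""
    u_rl = ppo_policy(obs)
    y_pred = rollout(obs, u_rl)                    # predicted outputs over a short horizon
    if np.all((y_pred > ymin) & (y_pred < ymax)):  # PPO action keeps outputs inside bounds
        return u_rl
    return dpc_policy(obs)                         # otherwise fall back to the DPC action

def blended_action(obs, ppo_policy, dpc_policy, alpha=0.5):
    """Weighted blend of the PPO and DPC actions."""
    return alpha * ppo_policy(obs) + (1.0 - alpha) * dpc_policy(obs)
```

Either helper could sit between the agent and `env.step` in the rollout loop.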
\ No newline at end of file diff --git a/src/neuromancer/rl/requirements.txt b/src/neuromancer/rl/requirements.txt new file mode 100644 index 00000000..fa8ba411 --- /dev/null +++ b/src/neuromancer/rl/requirements.txt @@ -0,0 +1,2 @@ +gym +ray[rllib] \ No newline at end of file From b70aafe59c2ef286cea8251cb49c167576900b00 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 14:32:05 +0100 Subject: [PATCH 03/17] running ppo in BuildingEnv --- .gitignore | 3 +- src/neuromancer/psl/gym.py | 26 +- src/neuromancer/rl/__init__.py | 5 + src/neuromancer/rl/ppo.py | 358 ++++++++++++++++++++++++++++ src/neuromancer/rl/requirements.txt | 5 +- src/neuromancer/rl/trainer.py | 252 ++++++++++++++++++++ src/neuromancer/utils.py | 17 +- 7 files changed, 653 insertions(+), 13 deletions(-) create mode 100644 src/neuromancer/rl/__init__.py create mode 100644 src/neuromancer/rl/ppo.py create mode 100644 src/neuromancer/rl/trainer.py diff --git a/.gitignore b/.gitignore index f0d1b9da..1d9869a2 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,5 @@ cython_debug/ building_parameters/ *.dot -*.png \ No newline at end of file +*.png +runs/ \ No newline at end of file diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index f2beef20..02814fe5 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -1,7 +1,7 @@ -from scipy.io import loadmat -from gym import spaces, Env - import numpy as np +from gymnasium import spaces, Env +from gymnasium.envs.registration import register +from neuromancer.utils import seed_everything from neuromancer.psl.building_envelope import BuildingEnvelope, systems @@ -31,18 +31,19 @@ def __init__(self, simulator, seed=None, fully_observable=True): self.observation_space = spaces.Box( -np.inf, np.inf, shape=self.model.x0.shape, dtype=np.float32) self.fully_observable = fully_observable - self.reset() + self.reset(seed=seed) def step(self, action): u = np.asarray(action) d = self.model.get_D(1).flatten() - # model should accept either 1D arrays or 2D (n, 1) arrays + # expect the model to accept both 1D arrays and 2D arrays self.x, self.y = self.model(self.x, u, d) self.t += 1 self.X_rec = np.append(self.X_rec, self.x) reward = self.reward(u, self.y) done = self.t == self.model.nsim - return self.obs, reward, done, dict(X_rec=self.X_rec) + truncated = False + return self.obs, reward, done, truncated, dict(X_rec=self.X_rec) def reward(self, u, y, ymin=21.0, ymax=23.0): # energy minimization @@ -57,12 +58,21 @@ def reward(self, u, y, ymin=21.0, ymax=23.0): def obs(self): return (self.y, self.x)[self.fully_observable].astype(np.float32) - def reset(self): + def reset(self, seed=None, options=None): + seed_everything(seed) self.t = 0 self.x = self.model.x0 self.y = self.model.C * self.x self.X_rec = np.empty(shape=[0, 4]) - return self.obs + return self.obs, dict(X_rec=self.X_rec) def render(self, mode='human'): pass + + +for env_id in systems: + register( + env_id, + entry_point='neuromancer.psl.gym:BuildingEnv', + kwargs=dict(simulator=env_id), + ) \ No newline at end of file diff --git a/src/neuromancer/rl/__init__.py b/src/neuromancer/rl/__init__.py new file mode 100644 index 00000000..9bded9e8 --- /dev/null +++ b/src/neuromancer/rl/__init__.py @@ -0,0 +1,5 @@ +from neuromancer.psl.gym import BuildingEnv +from neuromancer.rl.ppo import run + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py new file mode 100644 index 00000000..014af275 --- 
/dev/null +++ b/src/neuromancer/rl/ppo.py @@ -0,0 +1,358 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy +import os +import time +import tqdm +import random +from dataclasses import dataclass + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.normal import Normal +from torch.utils.tensorboard import SummaryWriter + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + save_model: bool = False + """whether to save model into the `runs/{run_name}` folder""" + upload_model: bool = False + """whether to upload the saved model to huggingface""" + hf_entity: str = "" + """the user or org name of the model repository from the Hugging Face Hub""" + + # Algorithm specific arguments + env_id: str = "SimpleSingleZone" + """the id of the environment""" + total_timesteps: int = 1000000 + """total timesteps of the experiments""" + learning_rate: float = 3e-4 + """the learning rate of the optimizer""" + num_envs: int = 1 + """the number of parallel game environments""" + num_steps: int = 2048 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 32 + """the number of mini-batches""" + update_epochs: int = 10 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.2 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.0 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, idx, capture_video, run_name, gamma): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space + env = gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.ClipAction(env) + env = gym.wrappers.NormalizeObservation(env) + env = 
gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) + env = gym.wrappers.NormalizeReward(env, gamma=gamma) + env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor_mean = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), + ) + self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None): + action_mean = self.actor_mean(x) + action_logstd = self.actor_logstd.expand_as(action_mean) + action_std = torch.exp(action_logstd) + probs = Normal(action_mean, action_std) + if action is None: + action = probs.sample() + return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) + + +def run(): + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset(seed=args.seed) + next_obs = torch.Tensor(next_obs).to(device) + next_done = 
torch.zeros(args.num_envs).to(device) + + prog_bar = tqdm.trange(1, args.num_iterations + 1, desc="PPO training") + prog_postfix = {} + for iteration in prog_bar: + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) + next_done = np.logical_or(terminations, truncations) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) + + if "final_info" in infos: + for info in infos["final_info"]: + if info and "episode" in info: + prog_postfix.update(steps=global_step, reward=info["episode"]["r"].mean()) + prog_bar.set_postfix(prog_postfix) + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + 
torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + prog_postfix.update(SPS=int(global_step / (time.time() - start_time))) + prog_bar.set_postfix(prog_postfix) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + torch.save(agent.state_dict(), model_path) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=Agent, + device=device, + gamma=args.gamma, + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() \ No newline at end of file diff --git a/src/neuromancer/rl/requirements.txt b/src/neuromancer/rl/requirements.txt index fa8ba411..5175284a 100644 --- a/src/neuromancer/rl/requirements.txt +++ b/src/neuromancer/rl/requirements.txt @@ -1,2 +1,3 @@ -gym -ray[rllib] \ No newline at end of file +gymnasium +tyro +tqdm \ No newline at end of file diff --git a/src/neuromancer/rl/trainer.py b/src/neuromancer/rl/trainer.py new file mode 100644 index 00000000..915060b9 --- /dev/null +++ b/src/neuromancer/rl/trainer.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + +import numpy as np +import torch +from torch import nn +from torch.distributions import Distribution, Independent, Normal +from torch.optim.lr_scheduler import LambdaLR + +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.env import EnvFactory, VectorEnvType +from tianshou.highlevel.logger import LoggerFactoryDefault +from 
tianshou.policy import PPOPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OnpolicyTrainer +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.continuous import ActorProb, Critic + +from neuromancer.psl.gym import BuildingEnv + + +class BuildingEnvFactory(EnvFactory): + def __init__(self, env_type, venv_type=VectorEnvType.DUMMY): + super().__init__(venv_type) + self.env_type = env_type + + def create_env(self, mode): + return BuildingEnv(self.env_type) + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="SimpleSingleZone") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=4096) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=30000) + parser.add_argument("--step-per-collect", type=int, default=2048) + parser.add_argument("--repeat-per-collect", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--training-num", type=int, default=8) + parser.add_argument("--test-num", type=int, default=10) + # ppo special + parser.add_argument("--rew-norm", type=int, default=True) + # In theory, `vf-coef` will not make any difference if using Adam optimizer. + parser.add_argument("--vf-coef", type=float, default=0.25) + parser.add_argument("--ent-coef", type=float, default=0.0) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--bound-action-method", type=str, default="clip") + parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--eps-clip", type=float, default=0.2) + parser.add_argument("--dual-clip", type=float, default=None) + parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--norm-adv", type=int, default=0) + parser.add_argument("--recompute-adv", type=int, default=1) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + return parser.parse_args() + + +def train_ppo(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = BuildingEnvFactory(args.task).create_envs(1, 1) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # 
model + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device, + ) + actor = ActorProb( + net_a, + args.action_shape, + unbounded=True, + device=args.device, + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device, + ) + critic = Critic(net_c, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) + + torch.nn.init.constant_(actor.sigma_param, -0.5) + for m in actor_critic.modules(): + if isinstance(m, torch.nn.Linear): + # orthogonal initialization + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + # do last policy layer scaling, this will make initial actions have (close to) + # 0 mean and std, and will help boost performances, + # see https://arxiv.org/abs/2006.05990, Fig.24 for details + for m in actor.mu.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.zeros_(m.bias) + m.weight.data.copy_(0.01 * m.weight.data) + + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + + lr_scheduler = None + if args.lr_decay: + # decay learning rate to 0 linearly + max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch + + lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) + + policy: PPOPolicy = PPOPolicy( + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + eps_clip=args.eps_clip, + value_clip=args.value_clip, + dual_clip=args.dual_clip, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv, + ) + + # load a previous policy + if args.resume_path: + ckpt = torch.load(args.resume_path, map_location=args.device) + policy.load_state_dict(ckpt["model"]) + train_envs.set_obs_rms(ckpt["obs_rms"]) + test_envs.set_obs_rms(ckpt["obs_rms"]) + print("Loaded agent from: ", args.resume_path) + + # collector + buffer: VectorReplayBuffer | ReplayBuffer + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ppo" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} + torch.save(state, os.path.join(log_path, "policy.pth")) + + if not args.watch: + # trainer + 
result = OnpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ).run() + pprint.pprint(result) + + # Let's watch its performance! + test_envs.seed(args.seed) + test_collector.reset() + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) + + +if __name__ == "__main__": + train_ppo() + # # %% + # env = BuildingEnv("SimpleSingleZone") + + # # %% + # T = 10 + # U = env.model.get_U(T) + # for i in range(T): + # ob, r, done, _ = env.step(U[i]) + # print(U[i]) + # print(ob) + # print(r) + + # # %% diff --git a/src/neuromancer/utils.py b/src/neuromancer/utils.py index 288fbc72..ffb78897 100644 --- a/src/neuromancer/utils.py +++ b/src/neuromancer/utils.py @@ -1,4 +1,6 @@ - +import os +import random +import numpy as np import torch import functools import lightning.pytorch as pl @@ -33,4 +35,15 @@ def load_state_dict_lightning(problem, weight_path): weights = torch.load(weight_path)['state_dict'] weights = OrderedDict({key.replace('problem.', '', 1): value for key, value in weights.items()}) problem.load_state_dict(weights) - return problem \ No newline at end of file + return problem + +def seed_everything(seed): + random.seed(seed) + np.random.seed(seed) + if seed is not None: + os.environ['PYTHONHASHSEED'] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.mps.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True From d18b6ac052fb13b19210fe5635a5c0eaee0d20a8 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 14:39:44 +0100 Subject: [PATCH 04/17] refactor tqdm usage --- src/neuromancer/rl/ppo.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index 014af275..ea22dc01 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -198,9 +198,12 @@ def run(): next_obs = torch.Tensor(next_obs).to(device) next_done = torch.zeros(args.num_envs).to(device) - prog_bar = tqdm.trange(1, args.num_iterations + 1, desc="PPO training") - prog_postfix = {} - for iteration in prog_bar: + def show_progress(bar=tqdm.trange(1, args.num_iterations + 1), postfix={}, **kwargs): + postfix.update(kwargs) + bar.set_postfix(postfix) + return bar + + for iteration in show_progress(): # Annealing the rate if instructed to do so. 
if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations @@ -228,8 +231,7 @@ def run(): if "final_info" in infos: for info in infos["final_info"]: if info and "episode" in info: - prog_postfix.update(steps=global_step, reward=info["episode"]["r"].mean()) - prog_bar.set_postfix(prog_postfix) + show_progress(steps=global_step, reward=info["episode"]["r"].mean()) writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) @@ -325,8 +327,7 @@ def run(): writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) writer.add_scalar("losses/explained_variance", explained_var, global_step) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - prog_postfix.update(SPS=int(global_step / (time.time() - start_time))) - prog_bar.set_postfix(prog_postfix) + show_progress(SPS=int(global_step / (time.time() - start_time))) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" From 0fcfb56c77afb66b25bd655f42d0ffeda018eed5 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 14:54:04 +0100 Subject: [PATCH 05/17] refactor --- src/neuromancer/psl/gym.py | 1 + src/neuromancer/rl/__init__.py | 5 ----- src/neuromancer/rl/ppo.py | 6 +++++- src/neuromancer/rl/tianshou | 1 + 4 files changed, 7 insertions(+), 6 deletions(-) create mode 160000 src/neuromancer/rl/tianshou diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index 02814fe5..07ecf030 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -70,6 +70,7 @@ def render(self, mode='human'): pass +# allow the custom envs to be directly instantiated by gym.make(env_id) for env_id in systems: register( env_id, diff --git a/src/neuromancer/rl/__init__.py b/src/neuromancer/rl/__init__.py index 9bded9e8..e69de29b 100644 --- a/src/neuromancer/rl/__init__.py +++ b/src/neuromancer/rl/__init__.py @@ -1,5 +0,0 @@ -from neuromancer.psl.gym import BuildingEnv -from neuromancer.rl.ppo import run - -if __name__ == "__main__": - run() \ No newline at end of file diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index ea22dc01..315eaad0 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -356,4 +356,8 @@ def show_progress(bar=tqdm.trange(1, args.num_iterations + 1), postfix={}, **kwa push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") envs.close() - writer.close() \ No newline at end of file + writer.close() + + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/src/neuromancer/rl/tianshou b/src/neuromancer/rl/tianshou new file mode 160000 index 00000000..2154065b --- /dev/null +++ b/src/neuromancer/rl/tianshou @@ -0,0 +1 @@ +Subproject commit 2154065bbc6830539117dba01636c50489eb11cf From a1958eacddf682ea64d6b833f5f34f436d0d00d7 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 14:56:19 +0100 Subject: [PATCH 06/17] remove trash --- src/neuromancer/rl/tianshou | 1 - src/neuromancer/rl/trainer.py | 252 ---------------------------------- 2 files changed, 253 deletions(-) delete mode 160000 src/neuromancer/rl/tianshou delete mode 100644 src/neuromancer/rl/trainer.py diff --git a/src/neuromancer/rl/tianshou b/src/neuromancer/rl/tianshou deleted file mode 160000 index 2154065b..00000000 --- a/src/neuromancer/rl/tianshou +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 
2154065bbc6830539117dba01636c50489eb11cf diff --git a/src/neuromancer/rl/trainer.py b/src/neuromancer/rl/trainer.py deleted file mode 100644 index 915060b9..00000000 --- a/src/neuromancer/rl/trainer.py +++ /dev/null @@ -1,252 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import datetime -import os -import pprint - -import numpy as np -import torch -from torch import nn -from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR - -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.env import EnvFactory, VectorEnvType -from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy -from tianshou.trainer import OnpolicyTrainer -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic - -from neuromancer.psl.gym import BuildingEnv - - -class BuildingEnvFactory(EnvFactory): - def __init__(self, env_type, venv_type=VectorEnvType.DUMMY): - super().__init__(venv_type) - self.env_type = env_type - - def create_env(self, mode): - return BuildingEnv(self.env_type) - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--task", type=str, default="SimpleSingleZone") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--lr", type=float, default=3e-4) - parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=10) - # ppo special - parser.add_argument("--rew-norm", type=int, default=True) - # In theory, `vf-coef` will not make any difference if using Adam optimizer. 
- parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=1) - parser.add_argument("--logdir", type=str, default="log") - parser.add_argument("--render", type=float, default=0.0) - parser.add_argument( - "--device", - type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) - parser.add_argument( - "--logger", - type=str, - default="tensorboard", - choices=["tensorboard", "wandb"], - ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") - parser.add_argument( - "--watch", - default=False, - action="store_true", - help="watch the play of pre-trained policy only", - ) - return parser.parse_args() - - -def train_ppo(args: argparse.Namespace = get_args()) -> None: - env, train_envs, test_envs = BuildingEnvFactory(args.task).create_envs(1, 1) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - print("Observations shape:", args.state_shape) - print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - # model - net_a = Net( - args.state_shape, - hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, - device=args.device, - ) - actor = ActorProb( - net_a, - args.action_shape, - unbounded=True, - device=args.device, - ).to(args.device) - net_c = Net( - args.state_shape, - hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, - device=args.device, - ) - critic = Critic(net_c, device=args.device).to(args.device) - actor_critic = ActorCritic(actor, critic) - - torch.nn.init.constant_(actor.sigma_param, -0.5) - for m in actor_critic.modules(): - if isinstance(m, torch.nn.Linear): - # orthogonal initialization - torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) - torch.nn.init.zeros_(m.bias) - # do last policy layer scaling, this will make initial actions have (close to) - # 0 mean and std, and will help boost performances, - # see https://arxiv.org/abs/2006.05990, Fig.24 for details - for m in actor.mu.modules(): - if isinstance(m, torch.nn.Linear): - torch.nn.init.zeros_(m.bias) - m.weight.data.copy_(0.01 * m.weight.data) - - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - - lr_scheduler = None - if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - - def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: - loc, scale = loc_scale - return Independent(Normal(loc, scale), 1) - - policy: PPOPolicy = PPOPolicy( - actor=actor, - critic=critic, - 
optim=optim, - dist_fn=dist, - discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - max_grad_norm=args.max_grad_norm, - vf_coef=args.vf_coef, - ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, - eps_clip=args.eps_clip, - value_clip=args.value_clip, - dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv, - ) - - # load a previous policy - if args.resume_path: - ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) - train_envs.set_obs_rms(ckpt["obs_rms"]) - test_envs.set_obs_rms(ckpt["obs_rms"]) - print("Loaded agent from: ", args.resume_path) - - # collector - buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: - buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - else: - buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) - - # log - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - args.algo_name = "ppo" - log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) - log_path = os.path.join(args.logdir, log_name) - - # logger - logger_factory = LoggerFactoryDefault() - if args.logger == "wandb": - logger_factory.logger_type = "wandb" - logger_factory.wandb_project = args.wandb_project - else: - logger_factory.logger_type = "tensorboard" - - logger = logger_factory.create_logger( - log_dir=log_path, - experiment_name=log_name, - run_id=args.resume_id, - config_dict=vars(args), - ) - - def save_best_fn(policy: BasePolicy) -> None: - state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} - torch.save(state, os.path.join(log_path, "policy.pth")) - - if not args.watch: - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() - pprint.pprint(result) - - # Let's watch its performance! 
- test_envs.seed(args.seed) - test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - print(collector_stats) - - -if __name__ == "__main__": - train_ppo() - # # %% - # env = BuildingEnv("SimpleSingleZone") - - # # %% - # T = 10 - # U = env.model.get_U(T) - # for i in range(T): - # ob, r, done, _ = env.step(U[i]) - # print(U[i]) - # print(ob) - # print(r) - - # # %% From 8cd7e1c7c90b60cdcef09bac6d41c04774aa4862 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 15:02:08 +0100 Subject: [PATCH 07/17] add import in PPO run() to register the envs --- src/neuromancer/rl/ppo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index 315eaad0..89e6e886 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -143,6 +143,8 @@ def get_action_and_value(self, x, action=None): def run(): + from neuromancer.psl.gym import BuildingEnv # register the envs + args = tyro.cli(Args) args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) From b5091f73cd56a1b9770c8356f5af63380419bcb8 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 16:10:07 +0100 Subject: [PATCH 08/17] update reward func --- src/neuromancer/psl/gym.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index 07ecf030..c49604f0 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -45,14 +45,14 @@ def step(self, action): truncated = False return self.obs, reward, done, truncated, dict(X_rec=self.X_rec) - def reward(self, u, y, ymin=21.0, ymax=23.0): - # energy minimization - action_loss = 0.1 * np.sum(u > 0.0) + def reward(self, u, y, ymin=20.0, ymax=22.0): + # power consumption minimization (u in W) + pc_loss = 0.001 * np.sum(u) - # thermal comfort - inbound_reward = 5. * np.sum((ymin < y) & (y < ymax)) + # thermal comfort (y in °C) + comfort_reward = 1. * np.sum((ymin < y) & (y < ymax)) - return inbound_reward - action_loss + return comfort_reward - pc_loss @property def obs(self): From 636d68cc899bdb80105789700f3acac01558b0fa Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 23 Oct 2024 17:19:57 +0100 Subject: [PATCH 09/17] update reward func --- src/neuromancer/psl/gym.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index c49604f0..001c5512 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -46,13 +46,16 @@ def step(self, action): return self.obs, reward, done, truncated, dict(X_rec=self.X_rec) def reward(self, u, y, ymin=20.0, ymax=22.0): - # power consumption minimization (u in W) - pc_loss = 0.001 * np.sum(u) + # energy minimization + # u[0] is the nominal mass flow rate, u[1] is the temperature difference + q = self.model.get_q(u).sum() # q is the heat flow in W + k = np.sum(u != 0.0) # number of actions + action_loss = 0.01 * q + 0.01 * k - # thermal comfort (y in °C) - comfort_reward = 1. * np.sum((ymin < y) & (y < ymax)) + # thermal comfort + comfort_reward = 5. 
* np.sum((ymin < y) & (y < ymax)) # y in °C - return comfort_reward - pc_loss + return comfort_reward - action_loss @property def obs(self): From 79cbc833af9da9498f44fc4703b4cabbf5ae5724 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Thu, 24 Oct 2024 09:13:04 +0100 Subject: [PATCH 10/17] Add SAC, project scheme and diagram --- .gitignore | 2 - src/neuromancer/rl/SCHEME.md | 85 ++++++++ src/neuromancer/rl/diagram.dot | 31 +++ src/neuromancer/rl/diagram.svg | 166 +++++++++++++++ src/neuromancer/rl/ppo.py | 21 +- src/neuromancer/rl/requirements.txt | 3 +- src/neuromancer/rl/sac.py | 303 ++++++++++++++++++++++++++++ src/neuromancer/rl/utils.py | 21 ++ 8 files changed, 610 insertions(+), 22 deletions(-) create mode 100644 src/neuromancer/rl/SCHEME.md create mode 100644 src/neuromancer/rl/diagram.dot create mode 100644 src/neuromancer/rl/diagram.svg create mode 100644 src/neuromancer/rl/sac.py create mode 100644 src/neuromancer/rl/utils.py diff --git a/.gitignore b/.gitignore index 1d9869a2..a0a6c970 100644 --- a/.gitignore +++ b/.gitignore @@ -170,6 +170,4 @@ cython_debug/ building_parameters/ -*.dot -*.png runs/ \ No newline at end of file diff --git a/src/neuromancer/rl/SCHEME.md b/src/neuromancer/rl/SCHEME.md new file mode 100644 index 00000000..354f6b66 --- /dev/null +++ b/src/neuromancer/rl/SCHEME.md @@ -0,0 +1,85 @@ +### **Project Scheme: Hybrid Control System with Differential Predictive Control (DPC) and Deep Reinforcement Learning (DRL)** + +--- + +### **Objective**: +The goal is to develop a hybrid control system by combining **Differential Predictive Control (DPC)** and **Deep Reinforcement Learning (DRL)** for efficient, robust control of a complex physical system. The system integrates a model based on **Ordinary Differential Equations (ODEs)** and **Neural State Space Models (NSSMs)** to augment control policies, with actor-critic DRL to optimize long-term strategy. + +--- + +### **Components**: + +1. **Physical System (Ground Truth)**: + - A real-world system with **limited access**, due to high cost, complexity, or experimental constraints. + +2. **System Model**: + - A **system model** based on **ODEs** or **Stochastic Differential Equations (SDEs)** to capture uncertainties and perturbations. This serves as the predictive model for **DPC**. + - The model may include **neural network (NN) terms**, such as in **Universal Differential Equations (UDEs)**, trained using real-world data when available. + - **NSSMs** are used to model the system dynamics and provide future state predictions to augment the inputs to control models. + +3. **Loss Function**: + - The objective function representing system performance (e.g., tracking error, energy consumption). This drives DPC optimization and defines the DRL reward. + +4. **Policy Model (Actor Network)**: + - An NN-based **control policy** that outputs actions. First trained via **DPC**, and later improved using **DRL** (e.g., PPO or SAC). + - The policy network receives **current states** and **NSSM-predicted future states** as inputs to enable foresight in decision-making. + +5. **Value Model (Critic Network)**: + - A **critic network** used in DRL to estimate long-term returns. It also receives **augmented inputs** from current states and NSSM predictions. + +--- + +### **Workflow**: + +#### **1. Model the Physical System Using ODE**: + - **System Model**: Model the physical system's dynamics with **ODEs** (optionally incorporating stochastic elements to capture uncertainties). 
This serves as the **system model** for short-term control in DPC. + - **NN Components**: If necessary, use real-world data to train any **neural network terms** in the system model. + +--- + +#### **2. Gather Real-World and Simulated Data**: + - **Data Collection**: Gather real-world data from the physical system and augment it with simulated data from the ODE-based system model. + - **Dataset**: Combine both real and simulated data into a dataset for NSSM and DPC training. + +--- + +#### **3. Train the Neural State Space Model (NSSM)**: + - **NSSM Training**: Train the **NSSM** using the collected dataset. The NSSM learns to predict future states of the system from current states and control inputs. + - **Input Augmentation**: Use NSSM-predicted next states to augment the inputs to the **policy model** (in DPC and DRL) and the **value model** (in DRL). + - This enables proactive decision-making by incorporating future state predictions into control actions. + +--- + +#### **4. Pre-train the Policy Network Using Differential Predictive Control (DPC)**: + - **DPC Training**: Pre-train the policy network with **DPC**, optimizing the control actions over a finite horizon using the **system model** (based on ODEs). + - **NSSM Predictions**: Augment the policy network's inputs with **NSSM-predicted future states** to improve decision-making. + - **Respect Constraints**: Ensure that the DPC respects system constraints, such as safety limits or actuator boundaries. + +--- + +#### **5. Train Policy Network Using DRL**: + - **Policy Initialization**: Initialize the **actor network** (policy) using the DPC-trained policy for a strong starting point. + - **Stochastic Exploration**: Ensure the policy includes some stochasticity to allow for exploration beyond the DPC-optimized policy. + - **DRL Optimization**: Refine the policy using DRL methods like **PPO** or **SAC** to maximize long-term performance. + - **Reward Function**: + - Define the reward as the **difference in losses** between the DPC and DRL policies: + \[ + R = \mathcal{L}_{\text{DPC}} - \mathcal{L}_{\text{DRL}} + \] + - This encourages the RL agent to improve over the DPC baseline policy. + - **Critic Network**: Randomly initialize the **critic network**, which will be trained alongside the policy during DRL. + +--- + +### **Final Summary**: + +1. **Model the Physical System**: Use ODEs (with stochastic elements if necessary) to represent system dynamics. +2. **Gather Data**: Collect real-world and simulated data for model training and policy optimization. +3. **Train NSSM**: Train the NSSM to predict future states, augmenting inputs to the control models. +4. **Pre-train Policy with DPC**: Use DPC to pre-train the policy using the system model. +5. **Train Policy with DRL**: Refine the policy using DRL (PPO or SAC), optimizing with the reward defined as the loss difference between DPC and DRL policies. + +--- + +### **Outcome**: +This hybrid framework combines the short-term, constraint-aware optimization of **DPC** with the long-term adaptability of **DRL**. By augmenting inputs with **NSSM-predicted future states**, the system gains foresight, allowing for more proactive, robust control strategies. The reward structure, comparing DPC and DRL policy performance, ensures continual improvement over the baseline. 
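A minimal sketch of the NSSM input augmentation (step 3) and the DPC-relative reward (step 5), assuming hypothetical `nssm`, `loss_fn`, and `dpc_policy` callables rather than existing neuromancer APIs:

```python
import numpy as np

def augment_obs(obs, u_prev, nssm):
    """Append the NSSM-predicted next state(s) to the raw observation (step 3)."""
    x_pred = nssm(obs, u_prev)  # predicted future state(s)
    return np.concatenate([np.ravel(obs), np.ravel(x_pred)])

def dpc_relative_reward(obs, u_rl, loss_fn, dpc_policy):
    """R = L_DPC - L_DRL: positive when the RL action improves on the DPC baseline (step 5)."""
    u_dpc = dpc_policy(obs)
    return loss_fn(obs, u_dpc) - loss_fn(obs, u_rl)
```

The augmented observation would feed both the actor and the critic, so their decisions incorporate the predicted future states.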
\ No newline at end of file
diff --git a/src/neuromancer/rl/diagram.dot b/src/neuromancer/rl/diagram.dot
new file mode 100644
index 00000000..f00e0aaf
--- /dev/null
+++ b/src/neuromancer/rl/diagram.dot
@@ -0,0 +1,31 @@
+digraph Hybrid_Control_System {
+    rankdir=LR;  // Left to Right layout
+    node [shape=box, style=rounded];
+
+    // Components
+    physical_system [label="Physical System"];
+    system_model [label="System Model (ODE/SDE)"];
+    data [label="Time Series Dataset"];
+    nssm [label="Neural State Space Model (NSSM)"];
+    policy_model [label="Policy Model (Actor)"];
+    value_model [label="Value Model (Critic)"];
+    loss [label="System Loss"];
+    reward [label="RL Reward"];
+
+    // Workflow connections
+    physical_system -> data [label="Generate Data"];
+    physical_system -> system_model [label="Modelling"];
+    system_model -> data [label="Simulate Data"];
+    system_model -> loss [label="Loss Function"];
+    policy_model -> system_model [label="Decision Making"];
+    data -> nssm [label="NSSM Training Data"];
+    data -> system_model [label="DPC Training Data"];
+    nssm -> policy_model [label="Augment Inputs"];
+    nssm -> value_model [label="Augment Inputs"];
+    loss -> reward[label="Loss(DPC) - Loss(DRL)"];
+    loss -> policy_model[label="Optimize by DPC"];
+    reward -> value_model [label="Cumulative Return"];
+    system_model -> value_model [label="Observation"];
+    system_model -> policy_model [label="Observation"];
+    value_model -> policy_model [label="Optimize by DRL"];
+}
diff --git a/src/neuromancer/rl/diagram.svg b/src/neuromancer/rl/diagram.svg
new file mode 100644
index 00000000..27e3155c
--- /dev/null
+++ b/src/neuromancer/rl/diagram.svg
@@ -0,0 +1,166 @@
+[Graphviz-rendered SVG of diagram.dot above: the same components (Physical System, System Model (ODE/SDE), Time Series Dataset, NSSM, Policy Model, Value Model, Loss, Reward) and workflow edges; the SVG markup did not survive text extraction and is omitted here.]
diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py
index 89e6e886..826a0391 100644
--- a/src/neuromancer/rl/ppo.py
+++ b/src/neuromancer/rl/ppo.py
@@ -14,6 +14,8 @@
 from torch.distributions.normal import Normal
 from torch.utils.tensorboard import SummaryWriter
 
+from neuromancer.rl.utils import make_env
+
 
 @dataclass
 class Args:
@@ -85,25 +87,6 @@ class Args:
     """the number of iterations (computed in runtime)"""
 
 
-def make_env(env_id, idx, capture_video, run_name, gamma):
-    def thunk():
-        if capture_video and idx == 0:
-            env = gym.make(env_id, render_mode="rgb_array")
-            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
-        else:
-            env = gym.make(env_id)
-        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
-        env =
gym.wrappers.RecordEpisodeStatistics(env) - env = gym.wrappers.ClipAction(env) - env = gym.wrappers.NormalizeObservation(env) - env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) - env = gym.wrappers.NormalizeReward(env, gamma=gamma) - env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) - return env - - return thunk - - def layer_init(layer, std=np.sqrt(2), bias_const=0.0): torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) diff --git a/src/neuromancer/rl/requirements.txt b/src/neuromancer/rl/requirements.txt index 5175284a..799da220 100644 --- a/src/neuromancer/rl/requirements.txt +++ b/src/neuromancer/rl/requirements.txt @@ -1,3 +1,4 @@ gymnasium tyro -tqdm \ No newline at end of file +tqdm +stable_baselines3 \ No newline at end of file diff --git a/src/neuromancer/rl/sac.py b/src/neuromancer/rl/sac.py new file mode 100644 index 00000000..b8572b3e --- /dev/null +++ b/src/neuromancer/rl/sac.py @@ -0,0 +1,303 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy +import os +import random +import time +from dataclasses import dataclass + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + +from neuromancer.rl.utils import make_env + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + + # Algorithm specific arguments + env_id: str = "SimpleSingleZone" + """the environment id of the task""" + num_envs: int = 1 + """the number of parallel game environments""" + total_timesteps: int = 1000000 + """total timesteps of the experiments""" + buffer_size: int = int(1e6) + """the replay memory buffer size""" + gamma: float = 0.99 + """the discount factor gamma""" + tau: float = 0.005 + """target smoothing coefficient (default: 0.005)""" + batch_size: int = 256 + """the batch size of sample from the reply memory""" + learning_starts: int = 5e3 + """timestep to start learning""" + policy_lr: float = 3e-4 + """the learning rate of the policy network optimizer""" + q_lr: float = 1e-3 + """the learning rate of the Q network network optimizer""" + policy_frequency: int = 2 + """the frequency of training policy (delayed)""" + target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2. 
+ """the frequency of updates for the target nerworks""" + alpha: float = 0.2 + """Entropy regularization coefficient.""" + autotune: bool = True + """automatic tuning of the entropy coefficient""" + + +# ALGO LOGIC: initialize agent here: +class SoftQNetwork(nn.Module): + def __init__(self, env): + super().__init__() + self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 1) + + def forward(self, x, a): + x = torch.cat([x, a], 1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +LOG_STD_MAX = 2 +LOG_STD_MIN = -5 + + +class Actor(nn.Module): + def __init__(self, env): + super().__init__() + self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) + self.fc2 = nn.Linear(256, 256) + self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape)) + self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape)) + # action rescaling + self.register_buffer( + "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32) + ) + self.register_buffer( + "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) + ) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + mean = self.fc_mean(x) + log_std = self.fc_logstd(x) + log_std = torch.tanh(log_std) + log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats + + return mean, log_std + + def get_action(self, x): + mean, log_std = self(x) + std = log_std.exp() + normal = torch.distributions.Normal(mean, std) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + y_t = torch.tanh(x_t) + action = y_t * self.action_scale + self.action_bias + log_prob = normal.log_prob(x_t) + # Enforcing Action Bound + log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) + log_prob = log_prob.sum(1, keepdim=True) + mean = torch.tanh(mean) * self.action_scale + self.action_bias + return action, log_prob, mean + + +if __name__ == "__main__": + import stable_baselines3 as sb3 + from neuromancer.psl.gym import BuildingEnv # register the envs + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: +poetry run pip install "stable_baselines3==2.0.0a1" +""" + ) + + args = tyro.cli(Args) + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + max_action = 
float(envs.single_action_space.high[0]) + + actor = Actor(envs).to(device) + qf1 = SoftQNetwork(envs).to(device) + qf2 = SoftQNetwork(envs).to(device) + qf1_target = SoftQNetwork(envs).to(device) + qf2_target = SoftQNetwork(envs).to(device) + qf1_target.load_state_dict(qf1.state_dict()) + qf2_target.load_state_dict(qf2.state_dict()) + q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr) + actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr) + + # Automatic entropy tuning + if args.autotune: + target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() + log_alpha = torch.zeros(1, requires_grad=True, device=device) + alpha = log_alpha.exp().item() + a_optimizer = optim.Adam([log_alpha], lr=args.q_lr) + else: + alpha = args.alpha + + envs.single_observation_space.dtype = np.float32 + rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + device, + handle_timeout_termination=False, + ) + start_time = time.time() + + # TRY NOT TO MODIFY: start the game + obs, _ = envs.reset(seed=args.seed) + for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) + actions = actions.detach().cpu().numpy() + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, terminations, truncations, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + if "final_info" in infos: + for info in infos["final_info"]: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break + + # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` + real_next_obs = next_obs.copy() + for idx, trunc in enumerate(truncations): + if trunc: + real_next_obs[idx] = infos["final_observation"][idx] + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. 
+ if global_step > args.learning_starts: + data = rb.sample(args.batch_size) + with torch.no_grad(): + next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations) + qf1_next_target = qf1_target(data.next_observations, next_state_actions) + qf2_next_target = qf2_target(data.next_observations, next_state_actions) + min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi + next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) + + qf1_a_values = qf1(data.observations, data.actions).view(-1) + qf2_a_values = qf2(data.observations, data.actions).view(-1) + qf1_loss = F.mse_loss(qf1_a_values, next_q_value) + qf2_loss = F.mse_loss(qf2_a_values, next_q_value) + qf_loss = qf1_loss + qf2_loss + + # optimize the model + q_optimizer.zero_grad() + qf_loss.backward() + q_optimizer.step() + + if global_step % args.policy_frequency == 0: # TD 3 Delayed update support + for _ in range( + args.policy_frequency + ): # compensate for the delay by doing 'actor_update_interval' instead of 1 + pi, log_pi, _ = actor.get_action(data.observations) + qf1_pi = qf1(data.observations, pi) + qf2_pi = qf2(data.observations, pi) + min_qf_pi = torch.min(qf1_pi, qf2_pi) + actor_loss = ((alpha * log_pi) - min_qf_pi).mean() + + actor_optimizer.zero_grad() + actor_loss.backward() + actor_optimizer.step() + + if args.autotune: + with torch.no_grad(): + _, log_pi, _ = actor.get_action(data.observations) + alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() + + a_optimizer.zero_grad() + alpha_loss.backward() + a_optimizer.step() + alpha = log_alpha.exp().item() + + # update the target networks + if global_step % args.target_network_frequency == 0: + for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + + if global_step % 100 == 0: + writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) + writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) + writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) + writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) + writer.add_scalar("losses/alpha", alpha, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + if args.autotune: + writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step) + + envs.close() + writer.close() \ No newline at end of file diff --git a/src/neuromancer/rl/utils.py b/src/neuromancer/rl/utils.py new file mode 100644 index 00000000..c3a23d85 --- /dev/null +++ b/src/neuromancer/rl/utils.py @@ -0,0 +1,21 @@ +import numpy as np +import gymnasium as gym + + +def make_env(env_id, idx, capture_video, run_name, gamma): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space + env = 
gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.ClipAction(env) + env = gym.wrappers.NormalizeObservation(env) + env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) + env = gym.wrappers.NormalizeReward(env, gamma=gamma) + env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) + return env + + return thunk \ No newline at end of file From 14e40f45400a9fcca392c040f9fda3a6bf0659b4 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Sat, 26 Oct 2024 21:02:23 +0100 Subject: [PATCH 11/17] replace activation function --- src/neuromancer/rl/ppo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index 826a0391..a5ebb3bc 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -98,16 +98,16 @@ def __init__(self, envs): super().__init__() self.critic = nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), - nn.Tanh(), + nn.LeakyReLU(), layer_init(nn.Linear(64, 64)), - nn.Tanh(), + nn.LeakyReLU(), layer_init(nn.Linear(64, 1), std=1.0), ) self.actor_mean = nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), - nn.Tanh(), + nn.LeakyReLU(), layer_init(nn.Linear(64, 64)), - nn.Tanh(), + nn.LeakyReLU(), layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), ) self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) From 21cb903a18dadaba248e01949f655264e7b8520e Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Sat, 26 Oct 2024 22:39:25 +0100 Subject: [PATCH 12/17] implement NSSM trainer for gym env --- src/neuromancer/rl/TODO.md | 20 ------ src/neuromancer/rl/nssm.py | 127 +++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 20 deletions(-) delete mode 100644 src/neuromancer/rl/TODO.md create mode 100644 src/neuromancer/rl/nssm.py diff --git a/src/neuromancer/rl/TODO.md b/src/neuromancer/rl/TODO.md deleted file mode 100644 index 1236b368..00000000 --- a/src/neuromancer/rl/TODO.md +++ /dev/null @@ -1,20 +0,0 @@ -- DPC as a Safety Layer for PPO - In this approach, PPO is the primary control policy that interacts with the environment to maximize long-term rewards. DPC acts as a safety layer that monitors PPO’s control actions and corrects them if they violate constraints or if the system is predicted to become unstable. - How it works: - PPO generates control actions based on its learned policy, focusing on maximizing long-term rewards through exploration and learning. - DPC monitors PPO’s actions in real time. Using a neural network, DPC predicts the future states of the system over a short horizon. If PPO’s proposed action leads to unsafe or suboptimal behavior (e.g., violating constraints or causing instability), DPC overrides PPO’s action with a safer one. - Fallback mechanism: If PPO’s action is safe and within the acceptable range, it is used. If not, DPC’s optimal control action is applied instead. -- DPC for Real-Time Control, PPO for Long-Term Policy Learning - In this approach, DPC is used for immediate predictive control, ensuring that the system adheres to constraints and is optimized by immediate feedback. PPO is responsible for learning the long-term control policy, helping the system adapt to changes and improve its performance over time. The control policy learned by PPO can guide or enhance the decisions made by DPC. 
- How it works: - DPC handles short-term optimization: At each time step, DPC uses a neural network to predict the system's future states over a short horizon and computes the optimal control action that minimizes a cost function while respecting constraints. - PPO updates the long-term policy: Over time, PPO learns a control policy that maximizes cumulative rewards by interacting with the environment. PPO can provide feedback to DPC in the form of improved control actions or policy adjustments. - Policy blending: You can blend the control policies from PPO and DPC by weighting them. -- PPO for Model Learning in DPC - In this approach, PPO is used to improve the neural network model used in DPC. While DPC typically relies on a pre-trained neural network to predict future states, PPO can continuously update and refine this model based on its interactions with the environment. - How it works: - DPC predicts short-term states using a neural network, and it computes optimal control actions based on these predictions. - PPO refines the neural network model: As PPO interacts with the environment, it improves its understanding of the system’s dynamics. PPO can then update the neural network used by DPC, making the predictions more accurate and improving DPC’s control performance. - Online learning: PPO continuously learns from real-time data, allowing DPC to adapt to changing system dynamics, external disturbances, or shifts in the environment. - - [ ] train DPC first, and then train the critic network using DPC policy, and use the DPC policy as initialization of PPO policy - - [ ] Use NSSM to predict the next model state and add it as extra input to the policy model. \ No newline at end of file diff --git a/src/neuromancer/rl/nssm.py b/src/neuromancer/rl/nssm.py new file mode 100644 index 00000000..baec0e74 --- /dev/null +++ b/src/neuromancer/rl/nssm.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from neuromancer.dataset import DictDataset +from neuromancer.system import Node, System +from neuromancer.trainer import Trainer +from neuromancer.problem import Problem +from neuromancer.constraint import variable +from neuromancer.loss import PenaltyLoss +from neuromancer.modules import blocks +from neuromancer.psl.gym import BuildingEnv + + +class SSM(nn.Module): + """ + Baseline class for (neural) state space model (SSM) + Implements discrete-time dynamical system: + x_k+1 = fx(x_k) + fu(u_k) + fd(d_k) + with variables: + x_k - states + u_k - control inputs + """ + def __init__(self, fx, fu, fd, nx, nu, nd): + super().__init__() + self.fx, self.fu, self.fd = fx, fu, fd + self.nx, self.nu, self.nd = nx, nu, nd + self.in_features, self.out_features = nx+nu+nd, nx + + def forward(self, x, u, d): + """ + :param x: (torch.Tensor, shape=[batchsize, nx]) + :param u: (torch.Tensor, shape=[batchsize, nu]) + :return: (torch.Tensor, shape=[batchsize, outsize]) + """ + # state space model + x = self.fx(x) + self.fu(u) + self.fd(d) + return x + + +class NSSMTrainer: + def __init__(self, env, hsizes=[64, 64], lr=1e-3, batch_size=100, epochs=1000): + self.env = env + self.lr = lr + self.hsizes = hsizes + self.batch_size = batch_size + self.epochs = epochs + + dl = self.get_simulation_data(1, 1, self.env.model.ts) + ny = dl.dataset[0]['Y'].shape[-1] + nu = dl.dataset[0]['U'].shape[-1] + nd = dl.dataset[0]['D'].shape[-1] + + fx = blocks.MLP(ny, ny, bias=True, linear_map=torch.nn.Linear, + nonlin=torch.nn.ReLU, hsizes=self.hsizes) + fu = blocks.MLP(nu, ny, 
bias=True, linear_map=torch.nn.Linear, + nonlin=torch.nn.ReLU, hsizes=self.hsizes) + fd = blocks.MLP(nd, ny, bias=True, linear_map=torch.nn.Linear, + nonlin=torch.nn.ReLU, hsizes=self.hsizes) + + ssm = SSM(fx, fu, fd, ny, nu, nd) + self.model = Node(ssm, ['yn', 'U', 'D'], ['yn'], name='NSSM') + + y = variable("Y") + yhat = variable('yn')[:, :-1, :] + + reference_loss = 10.*(yhat == y)^2 + reference_loss.name = "ref_loss" + + onestep_loss = 1.*(yhat[:, 1, :] == y[:, 1, :])^2 + onestep_loss.name = "onestep_loss" + + objectives = [reference_loss, onestep_loss] + constraints = [] + self.loss = PenaltyLoss(objectives, constraints) + + def normalize(self, x, mean, std): + return (x - mean) / std + + def get_simulation_data(self, nsim, nsteps, ts, name='data'): + nsim = nsim // nsteps * nsteps + sim = self.env.model.simulate(nsim=nsim, ts=ts) + sim = {k: sim[k] for k in ['X', 'Y', 'U', 'D']} + nbatches = nsim // nsteps + for key in sim: + m = self.env.model.stats[key]['mean'] + s = self.env.model.stats[key]['std'] + x = self.normalize(sim[key], m, s).reshape(nbatches, nsteps, -1) + x = torch.tensor(x, dtype=torch.float32) + sim[key] = x + sim['yn'] = sim['Y'][:, :1, :] + ds = DictDataset(sim, name=name) + return DataLoader(ds, batch_size=self.batch_size, collate_fn=ds.collate_fn, shuffle=True) + + def train(self, nsim=2000, nsteps=2, niters=5): + ts = self.env.model.ts + train_loader, dev_loader, test_loader = [ + self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) + for name in ['train', 'dev', 'test'] + ] + + dynamics = System([self.model], name='system') + problem = Problem([dynamics], self.loss) + optimizer = torch.optim.Adam(problem.parameters(), lr=self.lr) + + trainer = Trainer(problem, train_loader, dev_loader, test_loader, optimizer, + patience=100, warmup=100, epochs=self.epochs) + + for i in range(niters): # curriculum learning + print(f'Training with nsteps={nsteps}') + best_model = trainer.train() + print({k: float(v) for k, v in trainer.test(best_model).items() if 'loss' in k}) + if i == niters - 1: + break + nsteps *= 2 + trainer.train_data, trainer.dev_data, trainer.test_data = [ + self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) + for name in ['train', 'dev', 'test'] + ] + trainer.badcount = 0 + + return best_model + + +if __name__ == '__main__': + env = BuildingEnv(simulator='SimpleSingleZone') + trainer = NSSMTrainer(env, batch_size=100, epochs=10) + dynamics_model = trainer.train(nsim=2000, nsteps=2) \ No newline at end of file From e4fc47e6de7f746511f2f8adf9370f4a8fdbdcfd Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Sat, 26 Oct 2024 22:49:17 +0100 Subject: [PATCH 13/17] add init draft of hybrid control --- src/neuromancer/rl/hybrid_control.py | 134 +++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/neuromancer/rl/hybrid_control.py diff --git a/src/neuromancer/rl/hybrid_control.py b/src/neuromancer/rl/hybrid_control.py new file mode 100644 index 00000000..6fefa6f1 --- /dev/null +++ b/src/neuromancer/rl/hybrid_control.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.utils.data import DataLoader +from neuromancer.dataset import DictDataset +from neuromancer.system import Node, System +from neuromancer.modules import blocks +from neuromancer.trainer import Trainer +from neuromancer.problem import Problem +from neuromancer.constraint import variable +from neuromancer.loss import PenaltyLoss +from neuromancer.psl.gym import BuildingEnv +from 
neuromancer.psl.signals import step as step_signal +from neuromancer.plot import pltCL +from ppo import Agent, Args, run as ppo_run +from nssm import NSSMTrainer + +# Step 1: Define the physical system model using ODEs +env = BuildingEnv(simulator='SimpleSingleZone') +sys = env.model +nsim = 8000 +sim = sys.simulate(nsim=nsim, x0=sys.get_x0(), U=sys.get_U(nsim + 1)) + +# Step 2: Collect real-world and simulated data +nsteps = 100 +n_samples = 1000 +x_min = 18.0 +x_max = 22.0 + +list_xmin = [x_min + (x_max - x_min) * torch.rand(1, 1) * torch.ones(nsteps + 1, sys.ny) for _ in range(n_samples)] +xmin = torch.cat(list_xmin) +batched_xmin = xmin.reshape([n_samples, nsteps + 1, sys.ny]) +batched_xmax = batched_xmin + 2.0 + +list_dist = [torch.tensor(sys.get_D(nsteps)) for _ in range(n_samples)] +batched_dist = torch.stack(list_dist, dim=0) + +list_x0 = [torch.tensor(sys.get_x0().reshape(1, sys.nx)) for _ in range(n_samples)] +batched_x0 = torch.stack(list_x0, dim=0) + +train_data = DictDataset({'x': batched_x0, 'y': batched_x0[:, :, [3]], 'ymin': batched_xmin, 'ymax': batched_xmax, 'd': batched_dist}, name='train') +dev_data = DictDataset({'x': batched_x0, 'y': batched_x0[:, :, [3]], 'ymin': batched_xmin, 'ymax': batched_xmax, 'd': batched_dist}, name='dev') + +batch_size = 100 +train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=train_data.collate_fn, shuffle=False) +dev_loader = DataLoader(dev_data, batch_size=batch_size, collate_fn=dev_data.collate_fn, shuffle=False) + +# Step 3: Train the Neural State Space Model (NSSM) +nssm_trainer = NSSMTrainer(env, batch_size=100, epochs=1) +dynamics_model = nssm_trainer.train(nsim=2000, nsteps=2) + +# Step 4: Pre-train the policy network using DPC +A = torch.tensor(sys.params[2]['A']) +B = torch.tensor(sys.params[2]['Beta']) +C = torch.tensor(sys.params[2]['C']) +E = torch.tensor(sys.params[2]['E']) +umin = torch.tensor(sys.umin) +umax = torch.tensor(sys.umax) + +xnext = lambda x, u, d: x @ A.T + u @ B.T + d @ E.T +state_model = Node(xnext, ['x', 'u', 'd'], ['x'], name='SSM') +ynext = lambda x: x @ C.T +output_model = Node(ynext, ['x'], ['y'], name='y=Cx') + +dist_model = lambda d: d[:, sys.d_idx] +dist_obsv = Node(dist_model, ['d'], ['d_obsv'], name='dist_obsv') + +net = blocks.MLP_bounds(insize=sys.ny + 2 * sys.ny + sys.nd, outsize=sys.nu, hsizes=[32, 32], nonlin=nn.GELU, min=umin, max=umax) +policy = Node(net, ['y', 'ymin', 'ymax', 'd_obsv'], ['u'], name='policy') + +cl_system = System([dist_obsv, policy, state_model, output_model], nsteps=nsteps, name='cl_system') + +y = variable('y') +u = variable('u') +ymin = variable('ymin') +ymax = variable('ymax') + +action_loss = 0.01 * (u == 0.0) +du_loss = 0.1 * (u[:, :-1, :] - u[:, 1:, :] == 0.0) +state_lower_bound_penalty = 50.0 * (y > ymin) +state_upper_bound_penalty = 50.0 * (y < ymax) + +objectives = [action_loss, du_loss] +constraints = [state_lower_bound_penalty, state_upper_bound_penalty] + +loss = PenaltyLoss(objectives, constraints) +problem = Problem([cl_system], loss) + +epochs = 200 +optimizer = torch.optim.AdamW(problem.parameters(), lr=0.001) +trainer = Trainer(problem, train_loader, dev_loader, optimizer=optimizer, epochs=epochs, train_metric='train_loss', eval_metric='dev_loss', warmup=epochs) +best_model = trainer.train() +trainer.model.load_state_dict(best_model) + +# Step 5: Train the policy network using DRL +args = Args( + env_id='SimpleSingleZone', + total_timesteps=1000000, + learning_rate=3e-4, + num_envs=1, + num_steps=2048, + anneal_lr=True, + gamma=0.99, + 
gae_lambda=0.95, + num_minibatches=32, + update_epochs=10, + norm_adv=True, + clip_coef=0.2, + clip_vloss=True, + ent_coef=0.0, + vf_coef=0.5, + max_grad_norm=0.5, + target_kl=None +) + +ppo_run() + +# Step 6: Test the hybrid control system +nsteps_test = 2000 +np_refs = step_signal(nsteps_test + 1, 1, min=x_min, max=x_max, randsteps=5) +ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test + 1, 1) +ymax_val = ymin_val + 2.0 +torch_dist = torch.tensor(sys.get_D(nsteps_test + 1)).unsqueeze(0) +x0 = torch.tensor(sys.get_x0()).reshape(1, 1, sys.nx) +data = {'x': x0, 'y': x0[:, :, [3]], 'ymin': ymin_val, 'ymax': ymax_val, 'd': torch_dist} +cl_system.nsteps = nsteps_test +trajectories = cl_system(data) + +Umin = umin * np.ones([nsteps_test, sys.nu]) +Umax = umax * np.ones([nsteps_test, sys.nu]) +Ymin = trajectories['ymin'].detach().reshape(nsteps_test + 1, sys.ny) +Ymax = trajectories['ymax'].detach().reshape(nsteps_test + 1, sys.ny) + +pltCL(Y=trajectories['y'].detach().reshape(nsteps_test + 1, sys.ny), R=Ymax, X=trajectories['x'].detach().reshape(nsteps_test + 1, sys.nx), D=trajectories['d'].detach().reshape(nsteps_test + 1, sys.nd), U=trajectories['u'].detach().reshape(nsteps_test, sys.nu), Umin=Umin, Umax=Umax, Ymin=Ymin, Ymax=Ymax) \ No newline at end of file From 5737b4112826692a1bfd7afcc8631b1c22072b92 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Tue, 29 Oct 2024 16:29:53 +0000 Subject: [PATCH 14/17] Implement gym_dpc and train DPC, NSSM, and PPO for hybrid control --- src/neuromancer/psl/gym.py | 6 +- src/neuromancer/rl/gym_dpc.py | 151 ++++++++++++++++++++ src/neuromancer/rl/{nssm.py => gym_nssm.py} | 0 src/neuromancer/rl/hybrid_control.py | 60 ++------ src/neuromancer/rl/ppo.py | 6 +- 5 files changed, 167 insertions(+), 56 deletions(-) create mode 100644 src/neuromancer/rl/gym_dpc.py rename src/neuromancer/rl/{nssm.py => gym_nssm.py} (100%) diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index 001c5512..d161864d 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -20,12 +20,12 @@ class BuildingEnv(Env): ymax (float): Maximum threshold for thermal comfort. 
""" - def __init__(self, simulator, seed=None, fully_observable=True): + def __init__(self, simulator, seed=None, fully_observable=True, backend='numpy'): super().__init__() if isinstance(simulator, BuildingEnvelope): self.model = simulator else: - self.model = systems[simulator](seed=seed) + self.model = systems[simulator](seed=seed, backend=backend) self.action_space = spaces.Box( self.model.umin, self.model.umax, shape=self.model.umin.shape, dtype=np.float32) self.observation_space = spaces.Box( @@ -59,7 +59,7 @@ def reward(self, u, y, ymin=20.0, ymax=22.0): @property def obs(self): - return (self.y, self.x)[self.fully_observable].astype(np.float32) + return (self.y, self.x)[self.fully_observable] def reset(self, seed=None, options=None): seed_everything(seed) diff --git a/src/neuromancer/rl/gym_dpc.py b/src/neuromancer/rl/gym_dpc.py new file mode 100644 index 00000000..da857837 --- /dev/null +++ b/src/neuromancer/rl/gym_dpc.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from neuromancer.dataset import DictDataset +from neuromancer.system import Node, System +from neuromancer.trainer import Trainer +from neuromancer.problem import Problem +from neuromancer.constraint import variable +from neuromancer.loss import PenaltyLoss +from neuromancer.modules import blocks +from neuromancer.modules.activations import activations +from neuromancer.psl.gym import BuildingEnv +import neuromancer.psl as psl +import numpy as np + +class DPCTrainer: + def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200): + self.env = env + self.lr = lr + self.hsizes = hsizes + self.batch_size = batch_size + self.epochs = epochs + + sys = self.env.model + + # Extract system parameters + A = torch.tensor(sys.params[2]['A']) + B = torch.tensor(sys.params[2]['Beta']) + C = torch.tensor(sys.params[2]['C']) + E = torch.tensor(sys.params[2]['E']) + umin = torch.tensor(sys.umin) + umax = torch.tensor(sys.umax) + nx = sys.nx + nu = sys.nu + nd = E.shape[1] + nd_obsv = sys.nd + ny = sys.ny + nref = sys.ny + y_idx = 3 + + # Define state-space model + xnext = lambda x, u, d: x @ A.T + u @ B.T + d @ E.T + state_model = Node(xnext, ['x', 'u', 'd'], ['x'], name='SSM') + ynext = lambda x: x @ C.T + output_model = Node(ynext, ['x'], ['y'], name='y=Cx') + + # Partially observable disturbance model + dist_model = lambda d: d[:, sys.d_idx] + dist_obsv = Node(dist_model, ['d'], ['d_obsv'], name='dist_obsv') + + # Neural net control policy + net = blocks.MLP_bounds(insize=ny + 2*nref + nd_obsv, + outsize=nu, hsizes=self.hsizes, + nonlin=activations['gelu'], + min=umin, max=umax) + policy = Node(net, ['y', 'ymin', 'ymax', 'd_obsv'], ['u'], name='policy') + + # Closed-loop system model + self.cl_system = System([dist_obsv, policy, state_model, output_model], + nsteps=100, + name='cl_system') + + # Define objectives and constraints + y = variable('y') + u = variable('u') + ymin = variable('ymin') + ymax = variable('ymax') + + action_loss = 0.01 * (u == 0.0) + du_loss = 0.1 * (u[:, :-1, :] - u[:, 1:, :] == 0.0) + state_lower_bound_penalty = 50. * (y > ymin) + state_upper_bound_penalty = 50. 
* (y < ymax) + + objectives = [action_loss, du_loss] + constraints = [state_lower_bound_penalty, state_upper_bound_penalty] + + # Create optimization problem + self.loss = PenaltyLoss(objectives, constraints) + self.problem = Problem([self.cl_system], self.loss) + + # Set up optimizer and trainer + self.optimizer = torch.optim.AdamW(self.problem.parameters(), lr=self.lr) + self.trainer = Trainer(self.problem, None, None, self.optimizer, epochs=self.epochs) + + def get_simulation_data(self, nsim, nsteps, ts, name='data'): + nsim = nsim // nsteps * nsteps + sim = self.env.model.simulate(nsim=nsim, ts=ts) + sim = {k: sim[k] for k in ['X', 'Y', 'U', 'D']} + nbatches = nsim // nsteps + for key in sim: + m = self.env.model.stats[key]['mean'] + s = self.env.model.stats[key]['std'] + x = self.normalize(sim[key], m, s).reshape(nbatches, nsteps, -1) + x = torch.tensor(x, dtype=torch.float32) + sim[key] = x + sim['yn'] = sim['Y'][:, :1, :] + ds = DictDataset(sim, name=name) + return DataLoader(ds, batch_size=self.batch_size, collate_fn=ds.collate_fn, shuffle=True) + + def normalize(self, x, mean, std): + return (x - mean) / std + + def train(self, nsim=2000, nsteps=2, niters=5): + ts = self.env.model.ts + train_loader, dev_loader, test_loader = [ + self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) + for name in ['train', 'dev', 'test'] + ] + + self.trainer.train_data = train_loader + self.trainer.dev_data = dev_loader + self.trainer.test_data = test_loader + + for i in range(niters): + print(f'Training with nsteps={nsteps}') + best_model = self.trainer.train() + print({k: float(v) for k, v in self.trainer.test(best_model).items() if 'loss' in k}) + if i == niters - 1: + break + nsteps *= 2 + self.trainer.train_data, self.trainer.dev_data, self.trainer.test_data = [ + self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) + for name in ['train', 'dev', 'test'] + ] + self.trainer.badcount = 0 + + return best_model + + def test(self, nsteps_test=2000): + sys = self.env.model + x_min = 18. + x_max = 22. 
+ np_refs = psl.signals.step(nsteps_test+1, 1, min=x_min, max=x_max, randsteps=5) + ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test+1, 1) + ymax_val = ymin_val + 2.0 + torch_dist = torch.tensor(sys.get_D(nsteps_test+1)).unsqueeze(0) + x0 = torch.tensor(sys.get_x0()).reshape(1, 1, sys.nx) + data = {'x': x0, + 'y': x0[:, :, [3]], + 'ymin': ymin_val, + 'ymax': ymax_val, + 'd': torch_dist} + self.cl_system.nsteps = nsteps_test + trajectories = self.cl_system(data) + return trajectories + +if __name__ == '__main__': + env = BuildingEnv(simulator='SimpleSingleZone') + trainer = DPCTrainer(env, batch_size=100, epochs=10) + best_model = trainer.train(nsim=2000, nsteps=2) + trajectories = trainer.test(nsteps_test=2000) \ No newline at end of file diff --git a/src/neuromancer/rl/nssm.py b/src/neuromancer/rl/gym_nssm.py similarity index 100% rename from src/neuromancer/rl/nssm.py rename to src/neuromancer/rl/gym_nssm.py diff --git a/src/neuromancer/rl/hybrid_control.py b/src/neuromancer/rl/hybrid_control.py index 6fefa6f1..e7f63e7d 100644 --- a/src/neuromancer/rl/hybrid_control.py +++ b/src/neuromancer/rl/hybrid_control.py @@ -13,15 +13,14 @@ from neuromancer.psl.signals import step as step_signal from neuromancer.plot import pltCL from ppo import Agent, Args, run as ppo_run -from nssm import NSSMTrainer +from neuromancer.rl.gym_nssm import NSSMTrainer +from neuromancer.rl.gym_dpc import DPCTrainer # Import the DPCTrainer class # Step 1: Define the physical system model using ODEs env = BuildingEnv(simulator='SimpleSingleZone') sys = env.model -nsim = 8000 -sim = sys.simulate(nsim=nsim, x0=sys.get_x0(), U=sys.get_U(nsim + 1)) -# Step 2: Collect real-world and simulated data +# Step 2: Collect data nsteps = 100 n_samples = 1000 x_min = 18.0 @@ -47,50 +46,11 @@ # Step 3: Train the Neural State Space Model (NSSM) nssm_trainer = NSSMTrainer(env, batch_size=100, epochs=1) -dynamics_model = nssm_trainer.train(nsim=2000, nsteps=2) +dynamics_model = nssm_trainer.train(nsim=2000, nsteps=2, niters=1) # Step 4: Pre-train the policy network using DPC -A = torch.tensor(sys.params[2]['A']) -B = torch.tensor(sys.params[2]['Beta']) -C = torch.tensor(sys.params[2]['C']) -E = torch.tensor(sys.params[2]['E']) -umin = torch.tensor(sys.umin) -umax = torch.tensor(sys.umax) - -xnext = lambda x, u, d: x @ A.T + u @ B.T + d @ E.T -state_model = Node(xnext, ['x', 'u', 'd'], ['x'], name='SSM') -ynext = lambda x: x @ C.T -output_model = Node(ynext, ['x'], ['y'], name='y=Cx') - -dist_model = lambda d: d[:, sys.d_idx] -dist_obsv = Node(dist_model, ['d'], ['d_obsv'], name='dist_obsv') - -net = blocks.MLP_bounds(insize=sys.ny + 2 * sys.ny + sys.nd, outsize=sys.nu, hsizes=[32, 32], nonlin=nn.GELU, min=umin, max=umax) -policy = Node(net, ['y', 'ymin', 'ymax', 'd_obsv'], ['u'], name='policy') - -cl_system = System([dist_obsv, policy, state_model, output_model], nsteps=nsteps, name='cl_system') - -y = variable('y') -u = variable('u') -ymin = variable('ymin') -ymax = variable('ymax') - -action_loss = 0.01 * (u == 0.0) -du_loss = 0.1 * (u[:, :-1, :] - u[:, 1:, :] == 0.0) -state_lower_bound_penalty = 50.0 * (y > ymin) -state_upper_bound_penalty = 50.0 * (y < ymax) - -objectives = [action_loss, du_loss] -constraints = [state_lower_bound_penalty, state_upper_bound_penalty] - -loss = PenaltyLoss(objectives, constraints) -problem = Problem([cl_system], loss) - -epochs = 200 -optimizer = torch.optim.AdamW(problem.parameters(), lr=0.001) -trainer = Trainer(problem, train_loader, dev_loader, optimizer=optimizer, 
epochs=epochs, train_metric='train_loss', eval_metric='dev_loss', warmup=epochs) -best_model = trainer.train() -trainer.model.load_state_dict(best_model) +dpc_trainer = DPCTrainer(env, batch_size=100, epochs=200) +best_model = dpc_trainer.train(nsim=2000, nsteps=2, niters=5) # Step 5: Train the policy network using DRL args = Args( @@ -123,11 +83,11 @@ torch_dist = torch.tensor(sys.get_D(nsteps_test + 1)).unsqueeze(0) x0 = torch.tensor(sys.get_x0()).reshape(1, 1, sys.nx) data = {'x': x0, 'y': x0[:, :, [3]], 'ymin': ymin_val, 'ymax': ymax_val, 'd': torch_dist} -cl_system.nsteps = nsteps_test -trajectories = cl_system(data) +dpc_trainer.cl_system.nsteps = nsteps_test +trajectories = dpc_trainer.cl_system(data) -Umin = umin * np.ones([nsteps_test, sys.nu]) -Umax = umax * np.ones([nsteps_test, sys.nu]) +Umin = dpc_trainer.env.model.umin * np.ones([nsteps_test, sys.nu]) +Umax = dpc_trainer.env.model.umax * np.ones([nsteps_test, sys.nu]) Ymin = trajectories['ymin'].detach().reshape(nsteps_test + 1, sys.ny) Ymax = trajectories['ymax'].detach().reshape(nsteps_test + 1, sys.ny) diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index a5ebb3bc..39114af7 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -94,16 +94,16 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent(nn.Module): - def __init__(self, envs): + def __init__(self, envs, actor=None, critic=None): super().__init__() - self.critic = nn.Sequential( + self.critic = critic or nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), nn.LeakyReLU(), layer_init(nn.Linear(64, 64)), nn.LeakyReLU(), layer_init(nn.Linear(64, 1), std=1.0), ) - self.actor_mean = nn.Sequential( + self.actor_mean = actor or nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), nn.LeakyReLU(), layer_init(nn.Linear(64, 64)), From 27f47595a661992e5de763a043dc96c13da6f8e3 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 30 Oct 2024 17:36:29 +0000 Subject: [PATCH 15/17] fix dpc trainer --- src/neuromancer/rl/{SCHEME.md => README.md} | 0 src/neuromancer/rl/gym_dpc.py | 201 +++++++++++++------- src/neuromancer/rl/gym_nssm.py | 6 +- 3 files changed, 137 insertions(+), 70 deletions(-) rename src/neuromancer/rl/{SCHEME.md => README.md} (100%) diff --git a/src/neuromancer/rl/SCHEME.md b/src/neuromancer/rl/README.md similarity index 100% rename from src/neuromancer/rl/SCHEME.md rename to src/neuromancer/rl/README.md diff --git a/src/neuromancer/rl/gym_dpc.py b/src/neuromancer/rl/gym_dpc.py index da857837..d6dead88 100644 --- a/src/neuromancer/rl/gym_dpc.py +++ b/src/neuromancer/rl/gym_dpc.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader @@ -10,53 +11,60 @@ from neuromancer.modules import blocks from neuromancer.modules.activations import activations from neuromancer.psl.gym import BuildingEnv +from neuromancer.plot import pltCL import neuromancer.psl as psl -import numpy as np class DPCTrainer: - def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200): + def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200, xlim=(18, 22), nssm=None): self.env = env self.lr = lr self.hsizes = hsizes self.batch_size = batch_size self.epochs = epochs + self.xmin, self.xmax = xlim sys = self.env.model # Extract system parameters - A = torch.tensor(sys.params[2]['A']) - B = torch.tensor(sys.params[2]['Beta']) - C = 
torch.tensor(sys.params[2]['C']) - E = torch.tensor(sys.params[2]['E']) umin = torch.tensor(sys.umin) umax = torch.tensor(sys.umax) nx = sys.nx nu = sys.nu - nd = E.shape[1] nd_obsv = sys.nd ny = sys.ny nref = sys.ny - y_idx = 3 # Define state-space model - xnext = lambda x, u, d: x @ A.T + u @ B.T + d @ E.T - state_model = Node(xnext, ['x', 'u', 'd'], ['x'], name='SSM') - ynext = lambda x: x @ C.T - output_model = Node(ynext, ['x'], ['y'], name='y=Cx') + state_model = Node(sys, ['x', 'u', 'd'], ['x', 'y'], name='SSM') # Partially observable disturbance model dist_model = lambda d: d[:, sys.d_idx] dist_obsv = Node(dist_model, ['d'], ['d_obsv'], name='dist_obsv') + + insize = ny + 2*nref + nd_obsv + invars = ['y', 'ymin', 'ymax', 'd_obsv'] + + # Augment input features with NSSM estimation + if nssm is not None: + assert isinstance(nssm, nn.Module) + nssm.eval() + nssm = Node(nssm, ['x', 'u', 'd'], ['xh'], name='NSSM') + insize += nx + invars.append('xh') # Neural net control policy - net = blocks.MLP_bounds(insize=ny + 2*nref + nd_obsv, + net = blocks.MLP_bounds(insize=insize, outsize=nu, hsizes=self.hsizes, nonlin=activations['gelu'], min=umin, max=umax) - policy = Node(net, ['y', 'ymin', 'ymax', 'd_obsv'], ['u'], name='policy') + policy = Node(net, invars, ['u'], name='policy') # Closed-loop system model - self.cl_system = System([dist_obsv, policy, state_model, output_model], + if nssm is not None: + nodes = [dist_obsv, policy, nssm, state_model] + else: + nodes = [dist_obsv, policy, state_model] + self.cl_system = System(nodes, nsteps=100, name='cl_system') @@ -82,70 +90,129 @@ def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200): self.optimizer = torch.optim.AdamW(self.problem.parameters(), lr=self.lr) self.trainer = Trainer(self.problem, None, None, self.optimizer, epochs=self.epochs) - def get_simulation_data(self, nsim, nsteps, ts, name='data'): - nsim = nsim // nsteps * nsteps - sim = self.env.model.simulate(nsim=nsim, ts=ts) - sim = {k: sim[k] for k in ['X', 'Y', 'U', 'D']} - nbatches = nsim // nsteps - for key in sim: - m = self.env.model.stats[key]['mean'] - s = self.env.model.stats[key]['std'] - x = self.normalize(sim[key], m, s).reshape(nbatches, nsteps, -1) - x = torch.tensor(x, dtype=torch.float32) - sim[key] = x - sim['yn'] = sim['Y'][:, :1, :] - ds = DictDataset(sim, name=name) - return DataLoader(ds, batch_size=self.batch_size, collate_fn=ds.collate_fn, shuffle=True) - - def normalize(self, x, mean, std): - return (x - mean) / std - - def train(self, nsim=2000, nsteps=2, niters=5): - ts = self.env.model.ts - train_loader, dev_loader, test_loader = [ - self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) - for name in ['train', 'dev', 'test'] - ] - - self.trainer.train_data = train_loader - self.trainer.dev_data = dev_loader - self.trainer.test_data = test_loader - - for i in range(niters): - print(f'Training with nsteps={nsteps}') - best_model = self.trainer.train() - print({k: float(v) for k, v in self.trainer.test(best_model).items() if 'loss' in k}) - if i == niters - 1: - break - nsteps *= 2 - self.trainer.train_data, self.trainer.dev_data, self.trainer.test_data = [ - self.get_simulation_data(nsim=nsim, nsteps=nsteps, ts=ts, name=name) - for name in ['train', 'dev', 'test'] - ] - self.trainer.badcount = 0 - - return best_model + def get_simulation_data(self, nsim, nsteps, n_samples): + sys = self.env.model + nx = sys.nx + nref = sys.ny + y_idx = 3 + x_min, x_max = self.xmin, self.xmax + + # sampled references for training the 
policy + list_xmin = [x_min+(x_max-x_min)*torch.rand(1, 1)*torch.ones(nsteps+1, nref) + for k in range(n_samples)] + xmin = torch.cat(list_xmin) + batched_xmin = xmin.reshape([n_samples, nsteps+1, nref]) + batched_xmax = batched_xmin+2.0 + # get sampled disturbance trajectories from the simulation model + list_dist = [torch.tensor(sys.get_D(nsteps)) + for k in range(n_samples)] + batched_dist = torch.stack(list_dist, dim=0) + # get sampled initial conditions + list_x0 = [torch.tensor(sys.get_x0().reshape(1, nx)) + for k in range(n_samples)] + batched_x0 = torch.stack(list_x0, dim=0) + # Training dataset + train_data = DictDataset({'x': batched_x0, + 'y': batched_x0[:, :, [y_idx]], + 'ymin': batched_xmin, + 'ymax': batched_xmax, + 'd': batched_dist}, + name='train') + + # references for dev set + list_xmin = [x_min+(x_max-x_min)*torch.rand(1, 1)*torch.ones(nsteps+1, nref) + for k in range(n_samples)] + xmin = torch.cat(list_xmin) + batched_xmin = xmin.reshape([n_samples, nsteps+1, nref]) + batched_xmax = batched_xmin+2.0 + # get sampled disturbance trajectories from the simulation model + list_dist = [torch.tensor(sys.get_D(nsteps)) + for k in range(n_samples)] + batched_dist = torch.stack(list_dist, dim=0) + # get sampled initial conditions + list_x0 = [torch.tensor(sys.get_x0().reshape(1, nx)) + for k in range(n_samples)] + batched_x0 = torch.stack(list_x0, dim=0) + # Development dataset + dev_data = DictDataset({'x': batched_x0, + 'y': batched_x0[:, :, [y_idx]], + 'ymin': batched_xmin, + 'ymax': batched_xmax, + 'd': batched_dist}, + name='dev') + + # torch dataloaders + train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, + collate_fn=train_data.collate_fn, + shuffle=False) + dev_loader = torch.utils.data.DataLoader(dev_data, batch_size=self.batch_size, + collate_fn=dev_data.collate_fn, + shuffle=False) + return train_loader, dev_loader + + def train(self, nsim=8000, nsteps=100, nsamples=1000): + train_loader, dev_loader = self.get_simulation_data(nsim, nsteps, nsamples) + # Neuromancer trainer + trainer = Trainer( + self.problem, + train_loader, dev_loader, + optimizer=self.optimizer, + epochs=self.epochs, + train_metric='train_loss', + eval_metric='dev_loss', + warmup=self.epochs, + ) + # Train control policy + best_model = trainer.train() + # load best trained model + trainer.model.load_state_dict(best_model) def test(self, nsteps_test=2000): sys = self.env.model - x_min = 18. - x_max = 22. 
+ umin = torch.tensor(sys.umin) + umax = torch.tensor(sys.umax) + nx = sys.nx + nu = sys.nu + ny = sys.ny + nd = sys.nd + nref = sys.ny + x_min, x_max = self.xmin, self.xmax + y_idx = 3 + + # generate reference np_refs = psl.signals.step(nsteps_test+1, 1, min=x_min, max=x_max, randsteps=5) ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test+1, 1) - ymax_val = ymin_val + 2.0 + ymax_val = ymin_val+2.0 + # generate disturbance signal torch_dist = torch.tensor(sys.get_D(nsteps_test+1)).unsqueeze(0) - x0 = torch.tensor(sys.get_x0()).reshape(1, 1, sys.nx) + # initial data for closed loop simulation + x0 = torch.tensor(sys.get_x0()).reshape(1, 1, nx) data = {'x': x0, - 'y': x0[:, :, [3]], + 'y': x0[:, :, [y_idx]], 'ymin': ymin_val, 'ymax': ymax_val, 'd': torch_dist} self.cl_system.nsteps = nsteps_test + # perform closed-loop simulation trajectories = self.cl_system(data) - return trajectories + + # constraints bounds + Umin = umin * np.ones([nsteps_test, nu]) + Umax = umax * np.ones([nsteps_test, nu]) + Ymin = trajectories['ymin'].detach().reshape(nsteps_test+1, nref) + Ymax = trajectories['ymax'].detach().reshape(nsteps_test+1, nref) + # plot closed loop trajectories + pltCL(Y=trajectories['y'].detach().reshape(nsteps_test+1, ny), + R=Ymax, + X=trajectories['x'].detach().reshape(nsteps_test+1, nx), + D=trajectories['d'].detach().reshape(nsteps_test+1, nd), + U=trajectories['u'].detach().reshape(nsteps_test, nu), + Umin=Umin, Umax=Umax, Ymin=Ymin, Ymax=Ymax) + + if __name__ == '__main__': - env = BuildingEnv(simulator='SimpleSingleZone') - trainer = DPCTrainer(env, batch_size=100, epochs=10) - best_model = trainer.train(nsim=2000, nsteps=2) + env = BuildingEnv(simulator='SimpleSingleZone', backend='torch') + trainer = DPCTrainer(env, batch_size=100, epochs=10, nssm=None) + best_model = trainer.train(nsim=100, nsteps=100, nsamples=100) trajectories = trainer.test(nsteps_test=2000) \ No newline at end of file diff --git a/src/neuromancer/rl/gym_nssm.py b/src/neuromancer/rl/gym_nssm.py index baec0e74..59499433 100644 --- a/src/neuromancer/rl/gym_nssm.py +++ b/src/neuromancer/rl/gym_nssm.py @@ -57,8 +57,7 @@ def __init__(self, env, hsizes=[64, 64], lr=1e-3, batch_size=100, epochs=1000): fd = blocks.MLP(nd, ny, bias=True, linear_map=torch.nn.Linear, nonlin=torch.nn.ReLU, hsizes=self.hsizes) - ssm = SSM(fx, fu, fd, ny, nu, nd) - self.model = Node(ssm, ['yn', 'U', 'D'], ['yn'], name='NSSM') + self.net = SSM(fx, fu, fd, ny, nu, nd) y = variable("Y") yhat = variable('yn')[:, :-1, :] @@ -98,7 +97,8 @@ def train(self, nsim=2000, nsteps=2, niters=5): for name in ['train', 'dev', 'test'] ] - dynamics = System([self.model], name='system') + model = Node(self.net, ['yn', 'U', 'D'], ['yn'], name='NSSM') + dynamics = System([model], name='system') problem = Problem([dynamics], self.loss) optimizer = torch.optim.Adam(problem.parameters(), lr=self.lr) From 9bcd8540a1f7388fc1cfedc9ecb7b3405edf03b9 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 30 Oct 2024 17:38:39 +0000 Subject: [PATCH 16/17] fix gym_dpc testing --- src/neuromancer/rl/gym_dpc.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/neuromancer/rl/gym_dpc.py b/src/neuromancer/rl/gym_dpc.py index d6dead88..d0c9e5bc 100644 --- a/src/neuromancer/rl/gym_dpc.py +++ b/src/neuromancer/rl/gym_dpc.py @@ -181,7 +181,7 @@ def test(self, nsteps_test=2000): # generate reference np_refs = psl.signals.step(nsteps_test+1, 1, min=x_min, max=x_max, randsteps=5) - ymin_val = 
torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test+1, 1) + ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, -1, 1) ymax_val = ymin_val+2.0 # generate disturbance signal torch_dist = torch.tensor(sys.get_D(nsteps_test+1)).unsqueeze(0) @@ -199,13 +199,13 @@ def test(self, nsteps_test=2000): # constraints bounds Umin = umin * np.ones([nsteps_test, nu]) Umax = umax * np.ones([nsteps_test, nu]) - Ymin = trajectories['ymin'].detach().reshape(nsteps_test+1, nref) - Ymax = trajectories['ymax'].detach().reshape(nsteps_test+1, nref) + Ymin = trajectories['ymin'].detach().reshape(-1, nref) + Ymax = trajectories['ymax'].detach().reshape(-1, nref) # plot closed loop trajectories - pltCL(Y=trajectories['y'].detach().reshape(nsteps_test+1, ny), + pltCL(Y=trajectories['y'].detach().reshape(-1, ny), R=Ymax, - X=trajectories['x'].detach().reshape(nsteps_test+1, nx), - D=trajectories['d'].detach().reshape(nsteps_test+1, nd), + X=trajectories['x'].detach().reshape(-1, nx), + D=trajectories['d'].detach().reshape(-1, nd), U=trajectories['u'].detach().reshape(nsteps_test, nu), Umin=Umin, Umax=Umax, Ymin=Ymin, Ymax=Ymax) From 994b844f71a3aee36d6004ae0417016fe398e055 Mon Sep 17 00:00:00 2001 From: tztsai <1213815284@qq.com> Date: Wed, 30 Oct 2024 21:53:17 +0000 Subject: [PATCH 17/17] modularize PPO and make gym env obs compatible with DPC --- src/neuromancer/psl/gym.py | 40 ++-- src/neuromancer/rl/gym_dpc.py | 23 +- src/neuromancer/rl/gym_nssm.py | 2 +- src/neuromancer/rl/hybrid_control.py | 99 +++------ src/neuromancer/rl/ppo.py | 317 +++++++++++++-------------- 5 files changed, 228 insertions(+), 253 deletions(-) diff --git a/src/neuromancer/psl/gym.py b/src/neuromancer/psl/gym.py index d161864d..7ce4beef 100644 --- a/src/neuromancer/psl/gym.py +++ b/src/neuromancer/psl/gym.py @@ -1,4 +1,5 @@ import numpy as np +import torch from gymnasium import spaces, Env from gymnasium.envs.registration import register from neuromancer.utils import seed_everything @@ -20,32 +21,36 @@ class BuildingEnv(Env): ymax (float): Maximum threshold for thermal comfort. 
""" - def __init__(self, simulator, seed=None, fully_observable=True, backend='numpy'): + def __init__(self, simulator, seed=None, fully_observable=False, + ymin=20.0, ymax=22.0, backend='numpy'): super().__init__() if isinstance(simulator, BuildingEnvelope): self.model = simulator else: self.model = systems[simulator](seed=seed, backend=backend) + self.fully_observable = fully_observable + self.ymin = ymin + self.ymax = ymax + obs, _ = self.reset(seed=seed) self.action_space = spaces.Box( self.model.umin, self.model.umax, shape=self.model.umin.shape, dtype=np.float32) self.observation_space = spaces.Box( - -np.inf, np.inf, shape=self.model.x0.shape, dtype=np.float32) - self.fully_observable = fully_observable - self.reset(seed=seed) + -np.inf, np.inf, shape=[len(obs)], dtype=np.float32) def step(self, action): u = np.asarray(action) - d = self.model.get_D(1).flatten() + self.d = self.get_disturbance() # expect the model to accept both 1D arrays and 2D arrays - self.x, self.y = self.model(self.x, u, d) + self.x, self.y = self.model(self.x, u, self.d) self.t += 1 self.X_rec = np.append(self.X_rec, self.x) - reward = self.reward(u, self.y) + obs = self.get_obs() + reward = self.get_reward(u, self.y) done = self.t == self.model.nsim truncated = False - return self.obs, reward, done, truncated, dict(X_rec=self.X_rec) + return obs, reward, done, truncated, dict(X_rec=self.X_rec) - def reward(self, u, y, ymin=20.0, ymax=22.0): + def get_reward(self, u, y, ymin=20.0, ymax=22.0): # energy minimization # u[0] is the nominal mass flow rate, u[1] is the temperature difference q = self.model.get_q(u).sum() # q is the heat flow in W @@ -57,17 +62,24 @@ def reward(self, u, y, ymin=20.0, ymax=22.0): return comfort_reward - action_loss - @property - def obs(self): - return (self.y, self.x)[self.fully_observable] + def get_disturbance(self): + return self.model.get_D(1).flatten() + + def get_obs(self): + obs_mask = torch.as_tensor(self.model.C.flatten(), dtype=torch.bool) + self.y = self.x[obs_mask] + d = self.d if self.fully_observable else self.d[self.model.d_idx] + obs = self.x if self.fully_observable else self.y + obs = np.hstack([obs, self.ymin, self.ymax, d]) + return obs.astype(np.float32) def reset(self, seed=None, options=None): seed_everything(seed) self.t = 0 self.x = self.model.x0 - self.y = self.model.C * self.x + self.d = self.get_disturbance() self.X_rec = np.empty(shape=[0, 4]) - return self.obs, dict(X_rec=self.X_rec) + return self.get_obs(), dict(X_rec=self.X_rec) def render(self, mode='human'): pass diff --git a/src/neuromancer/rl/gym_dpc.py b/src/neuromancer/rl/gym_dpc.py index d0c9e5bc..7ae256cd 100644 --- a/src/neuromancer/rl/gym_dpc.py +++ b/src/neuromancer/rl/gym_dpc.py @@ -46,7 +46,6 @@ def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200, x # Augment input features with NSSM estimation if nssm is not None: - assert isinstance(nssm, nn.Module) nssm.eval() nssm = Node(nssm, ['x', 'u', 'd'], ['xh'], name='NSSM') insize += nx @@ -57,6 +56,7 @@ def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200, x outsize=nu, hsizes=self.hsizes, nonlin=activations['gelu'], min=umin, max=umax) + self.policy = net policy = Node(net, invars, ['u'], name='policy') # Closed-loop system model @@ -86,9 +86,8 @@ def __init__(self, env, hsizes=[32, 32], lr=0.001, batch_size=100, epochs=200, x self.loss = PenaltyLoss(objectives, constraints) self.problem = Problem([self.cl_system], self.loss) - # Set up optimizer and trainer + # Set up optimizer self.optimizer 
= torch.optim.AdamW(self.problem.parameters(), lr=self.lr) - self.trainer = Trainer(self.problem, None, None, self.optimizer, epochs=self.epochs) def get_simulation_data(self, nsim, nsteps, n_samples): sys = self.env.model @@ -104,12 +103,10 @@ def get_simulation_data(self, nsim, nsteps, n_samples): batched_xmin = xmin.reshape([n_samples, nsteps+1, nref]) batched_xmax = batched_xmin+2.0 # get sampled disturbance trajectories from the simulation model - list_dist = [torch.tensor(sys.get_D(nsteps)) - for k in range(n_samples)] + list_dist = [torch.as_tensor(sys.get_D(nsteps)) for k in range(n_samples)] batched_dist = torch.stack(list_dist, dim=0) # get sampled initial conditions - list_x0 = [torch.tensor(sys.get_x0().reshape(1, nx)) - for k in range(n_samples)] + list_x0 = [torch.as_tensor(sys.get_x0().reshape(1, nx)) for k in range(n_samples)] batched_x0 = torch.stack(list_x0, dim=0) # Training dataset train_data = DictDataset({'x': batched_x0, @@ -126,12 +123,10 @@ def get_simulation_data(self, nsim, nsteps, n_samples): batched_xmin = xmin.reshape([n_samples, nsteps+1, nref]) batched_xmax = batched_xmin+2.0 # get sampled disturbance trajectories from the simulation model - list_dist = [torch.tensor(sys.get_D(nsteps)) - for k in range(n_samples)] + list_dist = [torch.as_tensor(sys.get_D(nsteps)) for k in range(n_samples)] batched_dist = torch.stack(list_dist, dim=0) # get sampled initial conditions - list_x0 = [torch.tensor(sys.get_x0().reshape(1, nx)) - for k in range(n_samples)] + list_x0 = [torch.as_tensor(sys.get_x0().reshape(1, nx)) for k in range(n_samples)] batched_x0 = torch.stack(list_x0, dim=0) # Development dataset dev_data = DictDataset({'x': batched_x0, @@ -152,7 +147,7 @@ def get_simulation_data(self, nsim, nsteps, n_samples): def train(self, nsim=8000, nsteps=100, nsamples=1000): train_loader, dev_loader = self.get_simulation_data(nsim, nsteps, nsamples) - # Neuromancer trainer + trainer = Trainer( self.problem, train_loader, dev_loader, @@ -162,10 +157,14 @@ def train(self, nsim=8000, nsteps=100, nsamples=1000): eval_metric='dev_loss', warmup=self.epochs, ) + # Train control policy best_model = trainer.train() + # load best trained model trainer.model.load_state_dict(best_model) + + return trainer.model def test(self, nsteps_test=2000): sys = self.env.model diff --git a/src/neuromancer/rl/gym_nssm.py b/src/neuromancer/rl/gym_nssm.py index 59499433..c811f66c 100644 --- a/src/neuromancer/rl/gym_nssm.py +++ b/src/neuromancer/rl/gym_nssm.py @@ -84,7 +84,7 @@ def get_simulation_data(self, nsim, nsteps, ts, name='data'): m = self.env.model.stats[key]['mean'] s = self.env.model.stats[key]['std'] x = self.normalize(sim[key], m, s).reshape(nbatches, nsteps, -1) - x = torch.tensor(x, dtype=torch.float32) + x = torch.as_tensor(x, dtype=torch.float32) sim[key] = x sim['yn'] = sim['Y'][:, :1, :] ds = DictDataset(sim, name=name) diff --git a/src/neuromancer/rl/hybrid_control.py b/src/neuromancer/rl/hybrid_control.py index e7f63e7d..296e0de4 100644 --- a/src/neuromancer/rl/hybrid_control.py +++ b/src/neuromancer/rl/hybrid_control.py @@ -1,60 +1,26 @@ -import torch -import torch.nn as nn -import numpy as np -from torch.utils.data import DataLoader -from neuromancer.dataset import DictDataset -from neuromancer.system import Node, System -from neuromancer.modules import blocks -from neuromancer.trainer import Trainer -from neuromancer.problem import Problem -from neuromancer.constraint import variable -from neuromancer.loss import PenaltyLoss +import time from neuromancer.psl.gym import 
BuildingEnv -from neuromancer.psl.signals import step as step_signal -from neuromancer.plot import pltCL -from ppo import Agent, Args, run as ppo_run from neuromancer.rl.gym_nssm import NSSMTrainer from neuromancer.rl.gym_dpc import DPCTrainer # Import the DPCTrainer class +from neuromancer.rl.ppo import Agent, Args -# Step 1: Define the physical system model using ODEs -env = BuildingEnv(simulator='SimpleSingleZone') -sys = env.model +# Define the physical system model using ODEs +env = BuildingEnv(simulator='SimpleSingleZone', seed=1, backend='torch') -# Step 2: Collect data -nsteps = 100 -n_samples = 1000 -x_min = 18.0 -x_max = 22.0 +# Train the Neural State Space Model (NSSM) +# nssm_trainer = NSSMTrainer(env, batch_size=100, epochs=10) +# nssm_trainer.train(nsim=2000, nsteps=2) -list_xmin = [x_min + (x_max - x_min) * torch.rand(1, 1) * torch.ones(nsteps + 1, sys.ny) for _ in range(n_samples)] -xmin = torch.cat(list_xmin) -batched_xmin = xmin.reshape([n_samples, nsteps + 1, sys.ny]) -batched_xmax = batched_xmin + 2.0 +# Pre-train the policy network using DPC +dpc_trainer = DPCTrainer(env, batch_size=100, epochs=10) +dpc_trainer.train(nsim=100, nsteps=100, nsamples=100) -list_dist = [torch.tensor(sys.get_D(nsteps)) for _ in range(n_samples)] -batched_dist = torch.stack(list_dist, dim=0) +DPC_PRETRAINING = False -list_x0 = [torch.tensor(sys.get_x0().reshape(1, sys.nx)) for _ in range(n_samples)] -batched_x0 = torch.stack(list_x0, dim=0) - -train_data = DictDataset({'x': batched_x0, 'y': batched_x0[:, :, [3]], 'ymin': batched_xmin, 'ymax': batched_xmax, 'd': batched_dist}, name='train') -dev_data = DictDataset({'x': batched_x0, 'y': batched_x0[:, :, [3]], 'ymin': batched_xmin, 'ymax': batched_xmax, 'd': batched_dist}, name='dev') - -batch_size = 100 -train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=train_data.collate_fn, shuffle=False) -dev_loader = DataLoader(dev_data, batch_size=batch_size, collate_fn=dev_data.collate_fn, shuffle=False) - -# Step 3: Train the Neural State Space Model (NSSM) -nssm_trainer = NSSMTrainer(env, batch_size=100, epochs=1) -dynamics_model = nssm_trainer.train(nsim=2000, nsteps=2, niters=1) - -# Step 4: Pre-train the policy network using DPC -dpc_trainer = DPCTrainer(env, batch_size=100, epochs=200) -best_model = dpc_trainer.train(nsim=2000, nsteps=2, niters=5) - -# Step 5: Train the policy network using DRL +# Train the policy network using DRL args = Args( env_id='SimpleSingleZone', + seed=1, total_timesteps=1000000, learning_rate=3e-4, num_envs=1, @@ -73,22 +39,29 @@ target_kl=None ) -ppo_run() +args.batch_size = int(args.num_envs * args.num_steps) +args.minibatch_size = int(args.batch_size // args.num_minibatches) +args.num_iterations = args.total_timesteps // args.batch_size +run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" -# Step 6: Test the hybrid control system -nsteps_test = 2000 -np_refs = step_signal(nsteps_test + 1, 1, min=x_min, max=x_max, randsteps=5) -ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test + 1, 1) -ymax_val = ymin_val + 2.0 -torch_dist = torch.tensor(sys.get_D(nsteps_test + 1)).unsqueeze(0) -x0 = torch.tensor(sys.get_x0()).reshape(1, 1, sys.nx) -data = {'x': x0, 'y': x0[:, :, [3]], 'ymin': ymin_val, 'ymax': ymax_val, 'd': torch_dist} -dpc_trainer.cl_system.nsteps = nsteps_test -trajectories = dpc_trainer.cl_system(data) +if args.track: + import wandb -Umin = dpc_trainer.env.model.umin * np.ones([nsteps_test, sys.nu]) -Umax = dpc_trainer.env.model.umax * 
np.ones([nsteps_test, sys.nu]) -Ymin = trajectories['ymin'].detach().reshape(nsteps_test + 1, sys.ny) -Ymax = trajectories['ymax'].detach().reshape(nsteps_test + 1, sys.ny) + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) -pltCL(Y=trajectories['y'].detach().reshape(nsteps_test + 1, sys.ny), R=Ymax, X=trajectories['x'].detach().reshape(nsteps_test + 1, sys.nx), D=trajectories['d'].detach().reshape(nsteps_test + 1, sys.nd), U=trajectories['u'].detach().reshape(nsteps_test, sys.nu), Umin=Umin, Umax=Umax, Ymin=Ymin, Ymax=Ymax) \ No newline at end of file +if DPC_PRETRAINING: + # load the policy model pre-trained by DPC + agent = Agent(args, actor=dpc_trainer.policy) +else: + agent = Agent(args) + +agent.train() +agent.evaluate_and_save() \ No newline at end of file diff --git a/src/neuromancer/rl/ppo.py b/src/neuromancer/rl/ppo.py index 39114af7..2c163526 100644 --- a/src/neuromancer/rl/ppo.py +++ b/src/neuromancer/rl/ppo.py @@ -15,6 +15,7 @@ from torch.utils.tensorboard import SummaryWriter from neuromancer.rl.utils import make_env +from neuromancer.utils import seed_everything @dataclass @@ -94,8 +95,20 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class Agent(nn.Module): - def __init__(self, envs, actor=None, critic=None): + def __init__(self, args, envs=None, actor=None, critic=None, run_name="PPO Run"): super().__init__() + self.args = args + self.run_name = run_name + self.device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + self.writer = SummaryWriter(f"runs/{self.run_name}") + self.writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + seed_everything(args.seed) + self.envs = envs = envs or gym.vector.SyncVectorEnv( + [make_env(args.env_id, i, args.capture_video, self.run_name, args.gamma) for i in range(args.num_envs)] + ) self.critic = critic or nn.Sequential( layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), nn.LeakyReLU(), @@ -111,6 +124,8 @@ def __init__(self, envs, actor=None, critic=None): layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), ) self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) + self.optimizer = optim.Adam(self.parameters(), lr=args.learning_rate, eps=1e-5) + self.to(self.device) def get_value(self, x): return self.critic(x) @@ -124,162 +139,112 @@ def get_action_and_value(self, x, action=None): action = probs.sample() return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) + def train(self): + obs = torch.zeros((self.args.num_steps, self.args.num_envs) + self.envs.single_observation_space.shape).to(self.device) + actions = torch.zeros((self.args.num_steps, self.args.num_envs) + self.envs.single_action_space.shape).to(self.device) + logprobs = torch.zeros((self.args.num_steps, self.args.num_envs)).to(self.device) + rewards = torch.zeros((self.args.num_steps, self.args.num_envs)).to(self.device) + dones = torch.zeros((self.args.num_steps, self.args.num_envs)).to(self.device) + values = torch.zeros((self.args.num_steps, self.args.num_envs)).to(self.device) + + global_step = 0 + start_time = time.time() + next_obs, _ = self.envs.reset(seed=self.args.seed) + next_obs = torch.Tensor(next_obs).to(self.device) + next_done = torch.zeros(self.args.num_envs).to(self.device) + + pbar = 
tqdm.trange(1, self.args.num_iterations + 1) + def show_progress(postfix={}, **kwargs): + postfix.update(kwargs) + pbar.set_postfix(postfix) + + for iteration in pbar: + if self.args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / self.args.num_iterations + lrnow = frac * self.args.learning_rate + self.optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, self.args.num_steps): + global_step += self.args.num_envs + obs[step] = next_obs + dones[step] = next_done -def run(): - from neuromancer.psl.gym import BuildingEnv # register the envs - - args = tyro.cli(Args) - args.batch_size = int(args.num_envs * args.num_steps) - args.minibatch_size = int(args.batch_size // args.num_minibatches) - args.num_iterations = args.total_timesteps // args.batch_size - run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=vars(args), - name=run_name, - monitor_gym=True, - save_code=True, - ) - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - - # TRY NOT TO MODIFY: seeding - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = args.torch_deterministic - - device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") - - # env setup - envs = gym.vector.SyncVectorEnv( - [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] - ) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - - agent = Agent(envs).to(device) - optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) - - # ALGO Logic: Storage setup - obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) - actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) - logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) - rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) - dones = torch.zeros((args.num_steps, args.num_envs)).to(device) - values = torch.zeros((args.num_steps, args.num_envs)).to(device) - - # TRY NOT TO MODIFY: start the game - global_step = 0 - start_time = time.time() - next_obs, _ = envs.reset(seed=args.seed) - next_obs = torch.Tensor(next_obs).to(device) - next_done = torch.zeros(args.num_envs).to(device) - - def show_progress(bar=tqdm.trange(1, args.num_iterations + 1), postfix={}, **kwargs): - postfix.update(kwargs) - bar.set_postfix(postfix) - return bar - - for iteration in show_progress(): - # Annealing the rate if instructed to do so. - if args.anneal_lr: - frac = 1.0 - (iteration - 1.0) / args.num_iterations - lrnow = frac * args.learning_rate - optimizer.param_groups[0]["lr"] = lrnow - - for step in range(0, args.num_steps): - global_step += args.num_envs - obs[step] = next_obs - dones[step] = next_done - - # ALGO LOGIC: action logic - with torch.no_grad(): - action, logprob, _, value = agent.get_action_and_value(next_obs) - values[step] = value.flatten() - actions[step] = action - logprobs[step] = logprob - - # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) - next_done = np.logical_or(terminations, truncations) - rewards[step] = torch.tensor(reward).to(device).view(-1) - next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) - - if "final_info" in infos: - for info in infos["final_info"]: - if info and "episode" in info: - show_progress(steps=global_step, reward=info["episode"]["r"].mean()) - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - - # bootstrap value if not done + with torch.no_grad(): + action, logprob, _, value = self.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + next_obs, reward, terminations, truncations, infos = self.envs.step(action.cpu().numpy()) + next_done = np.logical_or(terminations, truncations) + rewards[step] = torch.tensor(reward).to(self.device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(self.device), torch.Tensor(next_done).to(self.device) + + if "final_info" in infos: + for info in infos["final_info"]: + if info and "episode" in info: + show_progress(steps=global_step, reward=info["episode"]["r"].mean()) + self.writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + self.writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + + self.optimize(obs, actions, logprobs, rewards, dones, values, next_obs, next_done, global_step, start_time) + show_progress(SPS=int(global_step / (time.time() - start_time))) + + def optimize(self, obs, actions, logprobs, rewards, dones, values, next_obs, next_done, global_step, start_time): with torch.no_grad(): - next_value = agent.get_value(next_obs).reshape(1, -1) - advantages = torch.zeros_like(rewards).to(device) + next_value = self.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(self.device) lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: + for t in reversed(range(self.args.num_steps)): + if t == self.args.num_steps - 1: nextnonterminal = 1.0 - next_done nextvalues = next_value else: nextnonterminal = 1.0 - dones[t + 1] nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + delta = rewards[t] + self.args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + self.args.gamma * self.args.gae_lambda * nextnonterminal * lastgaelam returns = advantages + values - # flatten the batch - b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_obs = obs.reshape((-1,) + self.envs.single_observation_space.shape) b_logprobs = logprobs.reshape(-1) - b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_actions = actions.reshape((-1,) + self.envs.single_action_space.shape) b_advantages = advantages.reshape(-1) b_returns = returns.reshape(-1) b_values = values.reshape(-1) - # Optimizing the policy and value network - b_inds = np.arange(args.batch_size) + b_inds = np.arange(self.args.batch_size) clipfracs = [] - for epoch in range(args.update_epochs): + for epoch in range(self.args.update_epochs): np.random.shuffle(b_inds) - for start in range(0, args.batch_size, args.minibatch_size): - end = start + args.minibatch_size + for start in range(0, 
self.args.batch_size, self.args.minibatch_size): + end = start + self.args.minibatch_size mb_inds = b_inds[start:end] - _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) + _, newlogprob, entropy, newvalue = self.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) logratio = newlogprob - b_logprobs[mb_inds] ratio = logratio.exp() with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html old_approx_kl = (-logratio).mean() approx_kl = ((ratio - 1) - logratio).mean() - clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + clipfracs += [((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()] mb_advantages = b_advantages[mb_inds] - if args.norm_adv: + if self.args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) - # Policy loss pg_loss1 = -mb_advantages * ratio - pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - self.args.clip_coef, 1 + self.args.clip_coef) pg_loss = torch.max(pg_loss1, pg_loss2).mean() - # Value loss newvalue = newvalue.view(-1) - if args.clip_vloss: + if self.args.clip_vloss: v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 v_clipped = b_values[mb_inds] + torch.clamp( newvalue - b_values[mb_inds], - -args.clip_coef, - args.clip_coef, + -self.args.clip_coef, + self.args.clip_coef, ) v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) @@ -288,61 +253,87 @@ def show_progress(bar=tqdm.trange(1, args.num_iterations + 1), postfix={}, **kwa v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() entropy_loss = entropy.mean() - loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + loss = pg_loss - self.args.ent_coef * entropy_loss + v_loss * self.args.vf_coef - optimizer.zero_grad() + self.optimizer.zero_grad() loss.backward() - nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) - optimizer.step() + nn.utils.clip_grad_norm_(self.parameters(), self.args.max_grad_norm) + self.optimizer.step() - if args.target_kl is not None and approx_kl > args.target_kl: + if self.args.target_kl is not None and approx_kl > self.args.target_kl: break y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y - # TRY NOT TO MODIFY: record rewards for plotting purposes - writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) - writer.add_scalar("losses/value_loss", v_loss.item(), global_step) - writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) - writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) - writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) - writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) - writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) - writer.add_scalar("losses/explained_variance", explained_var, global_step) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - show_progress(SPS=int(global_step / (time.time() - start_time))) - - if args.save_model: - model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" - torch.save(agent.state_dict(), model_path) - print(f"model saved to {model_path}") - from cleanrl_utils.evals.ppo_eval import evaluate - - 
episodic_returns = evaluate( - model_path, - make_env, - args.env_id, - eval_episodes=10, - run_name=f"{run_name}-eval", - Model=Agent, - device=device, - gamma=args.gamma, - ) - for idx, episodic_return in enumerate(episodic_returns): - writer.add_scalar("eval/episodic_return", episodic_return, idx) + self.writer.add_scalar("charts/learning_rate", self.optimizer.param_groups[0]["lr"], global_step) + self.writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + self.writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + self.writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + self.writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + self.writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + self.writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + self.writer.add_scalar("losses/explained_variance", explained_var, global_step) + self.writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + def evaluate_and_save(self): + if self.args.save_model: + model_path = f"runs/{self.run_name}/{self.args.exp_name}.cleanrl_model" + torch.save(self.state_dict(), model_path) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + self.args.env_id, + eval_episodes=10, + run_name=f"{self.run_name}-eval", + Model=Agent, + device=self.device, + gamma=self.args.gamma, + ) + for idx, episodic_return in enumerate(episodic_returns): + self.writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if self.args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{self.args.env_id}-{self.args.exp_name}-seed{self.args.seed}" + repo_id = f"{self.args.hf_entity}/{repo_name}" if self.args.hf_entity else repo_name + push_to_hub(self.args, episodic_returns, repo_id, "PPO", f"runs/{self.run_name}", f"videos/{self.run_name}-eval") + + self.envs.close() + self.writer.close() - if args.upload_model: - from cleanrl_utils.huggingface import push_to_hub - repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" - repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") +def run(): + from neuromancer.psl.gym import BuildingEnv # register the envs - envs.close() - writer.close() + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + + agent = Agent(args, run_name=run_name) + agent.train() + agent.evaluate_and_save() + if __name__ == "__main__": run() \ No newline at end of file
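
Note for reviewers: the advantage computation inside Agent.optimize is the standard generalized advantage estimation (GAE) recursion carried over from the original CleanRL script. For readers following the refactor, the same logic isolated as a standalone helper might look like the sketch below. It is not part of this patch; compute_gae is a hypothetical name, and the tensors follow the (num_steps, num_envs) layout of the storage buffers allocated in Agent.train.

import torch

def compute_gae(rewards, values, dones, next_value, next_done, gamma, gae_lambda):
    """Generalized advantage estimation over a rollout of shape (num_steps, num_envs)."""
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done      # do not bootstrap past a terminal state
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # one-step TD residual
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        # exponentially weighted accumulation: A_t = delta_t + gamma * lambda * A_{t+1}
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values                  # regression targets for the value loss
    return advantages, returns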
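
A quick way to exercise the reworked BuildingEnv API from earlier in this patch series (reset returning (obs, info) and step returning the 5-tuple (obs, reward, done, truncated, info)) is a short random rollout. This is a smoke-test sketch for review purposes only; it assumes the gym.py changes above are applied and that 'SimpleSingleZone' remains a valid key in the PSL systems registry.

from neuromancer.psl.gym import BuildingEnv

env = BuildingEnv('SimpleSingleZone', seed=0)          # wraps the PSL building simulator
obs, info = env.reset(seed=0)                          # reset() now returns (obs, info)
assert obs.shape == env.observation_space.shape        # obs stacks state/output, [ymin, ymax], disturbance

total_reward = 0.0
for _ in range(10):
    u = env.action_space.sample()                      # random control inside [umin, umax]
    obs, reward, done, truncated, info = env.step(u)   # Gymnasium-style 5-tuple step
    total_reward += float(reward)
    if done or truncated:
        obs, info = env.reset()
print(f"cumulative reward over 10 random steps: {total_reward:.2f}")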