From d0348fafd75c2d7f230dfdf36681be4b91498f39 Mon Sep 17 00:00:00 2001
From: Laurin Luttmann
Date: Thu, 12 Sep 2024 22:41:04 +0200
Subject: [PATCH] [Feat] added self-labeling training algorithm

---
 configs/experiment/scheduling/am-ppo.yaml     |   1 +
 configs/experiment/scheduling/base.yaml       |   1 -
 .../experiment/scheduling/episodic-ppo.yaml   |  35 ++++
 configs/experiment/scheduling/gnn-ppo.yaml    |   1 +
 configs/experiment/scheduling/hgnn-ppo.yaml   |   1 +
 configs/experiment/scheduling/matnet-ppo.yaml |   1 +
 configs/experiment/scheduling/sl.yaml         |  40 +++++
 rl4co/models/__init__.py                      |   2 +-
 rl4co/models/rl/__init__.py                   |   1 +
 .../rl/self_supervised/self_labeling.py       | 160 ++++++++++++++++++
 rl4co/models/zoo/l2d/policy.py                |   8 +-
 11 files changed, 246 insertions(+), 5 deletions(-)
 create mode 100644 configs/experiment/scheduling/episodic-ppo.yaml
 create mode 100644 configs/experiment/scheduling/sl.yaml
 create mode 100644 rl4co/models/rl/self_supervised/self_labeling.py

diff --git a/configs/experiment/scheduling/am-ppo.yaml b/configs/experiment/scheduling/am-ppo.yaml
index f9e5d354..c980bc43 100644
--- a/configs/experiment/scheduling/am-ppo.yaml
+++ b/configs/experiment/scheduling/am-ppo.yaml
@@ -45,6 +45,7 @@ model:
   test_batch_size: 64
   train_data_size: 2000
   mini_batch_size: 512
+  max_grad_norm: 1
 
 env:
   stepwise_reward: True
\ No newline at end of file
diff --git a/configs/experiment/scheduling/base.yaml b/configs/experiment/scheduling/base.yaml
index c15a6c45..88ebeca0 100644
--- a/configs/experiment/scheduling/base.yaml
+++ b/configs/experiment/scheduling/base.yaml
@@ -37,4 +37,3 @@ model:
   lr_scheduler_kwargs:
     gamma: 0.95
   reward_scale: scale
-  max_grad_norm: 1
diff --git a/configs/experiment/scheduling/episodic-ppo.yaml b/configs/experiment/scheduling/episodic-ppo.yaml
new file mode 100644
index 00000000..50107006
--- /dev/null
+++ b/configs/experiment/scheduling/episodic-ppo.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+defaults:
+  - scheduling/base
+
+logger:
+  wandb:
+    tags: ["hgnn-ppo", "${env.name}"]
+    name: "hgnn-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m"
+
+# params from Song et al.
+model:
+  _target_: rl4co.models.L2DModel
+  policy_kwargs:
+    embed_dim: 128
+    num_encoder_layers: 3
+    scaling_factor: ${scaling_factor}
+  max_grad_norm: 1
+  ppo_epochs: 3
+  het_emb: True
+  batch_size: 128
+  val_batch_size: 512
+  test_batch_size: 64
+  mini_batch_size: 512
+  # reward_scale: scale
+  optimizer_kwargs:
+    lr: 1e-4
+
+trainer:
+  max_epochs: 10
+
+
+env:
+  stepwise_reward: False
+  _torchrl_mode: False
\ No newline at end of file
diff --git a/configs/experiment/scheduling/gnn-ppo.yaml b/configs/experiment/scheduling/gnn-ppo.yaml
index d2139eea..d965449e 100644
--- a/configs/experiment/scheduling/gnn-ppo.yaml
+++ b/configs/experiment/scheduling/gnn-ppo.yaml
@@ -23,6 +23,7 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
+  max_grad_norm: 1
 
 
 trainer:
diff --git a/configs/experiment/scheduling/hgnn-ppo.yaml b/configs/experiment/scheduling/hgnn-ppo.yaml
index 7d46f7d7..2881f1b2 100644
--- a/configs/experiment/scheduling/hgnn-ppo.yaml
+++ b/configs/experiment/scheduling/hgnn-ppo.yaml
@@ -22,6 +22,7 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
+  max_grad_norm: 1
 
 env:
   stepwise_reward: True
\ No newline at end of file
diff --git a/configs/experiment/scheduling/matnet-ppo.yaml b/configs/experiment/scheduling/matnet-ppo.yaml
index c88d2c64..d1fc49f0 100644
--- a/configs/experiment/scheduling/matnet-ppo.yaml
+++ b/configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,7 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
+  max_grad_norm: 1
 
 env:
   stepwise_reward: True
\ No newline at end of file
diff --git a/configs/experiment/scheduling/sl.yaml b/configs/experiment/scheduling/sl.yaml
new file mode 100644
index 00000000..838642ca
--- /dev/null
+++ b/configs/experiment/scheduling/sl.yaml
@@ -0,0 +1,40 @@
+# @package _global_
+
+defaults:
+  - scheduling/base
+
+logger:
+  wandb:
+    tags: ["matnet-pomo", "${env.name}"]
+    name: "matnet-pomo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m"
+
+embed_dim: 256
+
+model:
+  _target_: rl4co.models.SelfLabeling
+  policy:
+    _target_: rl4co.models.L2DPolicy4PPO
+    decoder:
+      _target_: rl4co.models.zoo.l2d.decoder.L2DDecoder
+      env_name: ${env.name}
+      embed_dim: ${embed_dim}
+      het_emb: True
+      feature_extractor:
+        _target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder
+        embed_dim: ${embed_dim}
+        num_heads: 8
+        num_layers: 4
+        normalization: "batch"
+        init_embedding:
+          _target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding
+          embed_dim: ${embed_dim}
+          scaling_factor: ${scaling_factor}
+    env_name: ${env.name}
+    embed_dim: ${embed_dim}
+    scaling_factor: ${scaling_factor}
+    het_emb: True
+  batch_size: 64
+  num_starts: 10
+  metrics:
+    val: ["reward", "max_reward"]
+    test: ${model.metrics.val}
diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py
index 339c3b01..1fd65eb1 100644
--- a/rl4co/models/__init__.py
+++ b/rl4co/models/__init__.py
@@ -14,7 +14,7 @@
     NonAutoregressivePolicy,
 )
 from rl4co.models.common.transductive import TransductiveModel
-from rl4co.models.rl import StepwisePPO
+from rl4co.models.rl import SelfLabeling, StepwisePPO
 from rl4co.models.rl.a2c.a2c import A2C
 from rl4co.models.rl.common.base import RL4COLitModule
 from rl4co.models.rl.ppo.ppo import PPO
diff --git a/rl4co/models/rl/__init__.py b/rl4co/models/rl/__init__.py
index 1a3bf7e2..a24ce49e 100644
--- a/rl4co/models/rl/__init__.py
+++ b/rl4co/models/rl/__init__.py
@@ -4,3 +4,4 @@
 from rl4co.models.rl.ppo.ppo import PPO
 from rl4co.models.rl.ppo.stepwise_ppo import StepwisePPO
 from rl4co.models.rl.reinforce.reinforce import REINFORCE
+from rl4co.models.rl.self_supervised.self_labeling import SelfLabeling
diff --git a/rl4co/models/rl/self_supervised/self_labeling.py b/rl4co/models/rl/self_supervised/self_labeling.py
new file mode 100644
index 00000000..f797ed09
--- /dev/null
+++ b/rl4co/models/rl/self_supervised/self_labeling.py
@@ -0,0 +1,160 @@
+import copy
+
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+
+from torch.nn import CrossEntropyLoss
+from torchrl.data.replay_buffers import (
+    LazyMemmapStorage,
+    ListStorage,
+    SamplerWithoutReplacement,
+    TensorDictReplayBuffer,
+)
+
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.models.rl.common.base import RL4COLitModule
+from rl4co.utils.ops import batchify, unbatchify
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+def make_replay_buffer(buffer_size, batch_size, device="cpu"):
+    if device == "cpu":
+        storage = LazyMemmapStorage(buffer_size, device="cpu")
+        prefetch = 3
+    else:
+        storage = ListStorage(buffer_size)
+        prefetch = None
+    return TensorDictReplayBuffer(
+        storage=storage,
+        batch_size=batch_size,
+        sampler=SamplerWithoutReplacement(drop_last=True),
+        pin_memory=False,
+        prefetch=prefetch,
+    )
+
+
+class SelfLabeling(RL4COLitModule):
+    def __init__(
+        self,
+        env: RL4COEnvBase,
+        policy: nn.Module,
+        clip_range: float = 0.2,  # epsilon of PPO
+        update_timestep: int = 1,
+        buffer_size: int = 100_000,
+        sl_epochs: int = 1,  # inner epoch, K
+        batch_size: int = 256,
+        mini_batch_size: int = 256,
+        vf_lambda: float = 0.5,  # lambda of Value function fitting
+        entropy_lambda: float = 0.01,  # lambda of entropy bonus
+        max_grad_norm: float = 0.5,  # max gradient norm
+        buffer_storage_device: str = "gpu",
+        metrics: dict = {
+            "train": ["loss", "surrogate_loss", "value_loss", "entropy"],
+        },
+        reward_scale: Union[str, int] = None,
+        num_starts: int = None,
+        **kwargs,
+    ):
+        super().__init__(env, policy, metrics=metrics, batch_size=batch_size, **kwargs)
+
+        self.policy_old = copy.deepcopy(self.policy)
+        self.automatic_optimization = False  # PPO uses custom optimization routine
+        self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
+        self.sl_epochs = sl_epochs
+        self.max_grad_norm = max_grad_norm
+        self.update_timestep = update_timestep
+        self.mini_batch_size = mini_batch_size
+        self.num_starts = num_starts
+
+    def update(self, eval_td, device):
+        losses = []
+        # PPO inner epoch
+        for _ in range(self.sl_epochs):
+            for sub_td in self.rb:
+                sub_td = sub_td.to(device)
+
+                logprobs, _, _ = self.policy.evaluate(sub_td, return_selected=False)
+
+                criterion = CrossEntropyLoss(reduction="mean")
+                # compute total loss
+                loss = criterion(logprobs, sub_td["action"])
+
+                opt = self.optimizers()
+                opt.zero_grad()
+                self.manual_backward(loss)
+                if self.max_grad_norm is not None:
+                    self.clip_gradients(
+                        opt,
+                        gradient_clip_val=self.max_grad_norm,
+                        gradient_clip_algorithm="norm",
+                    )
+
+                opt.step()
+                losses.append(loss)
+
+        # need eval for greedy decoding
+        out = self.policy.generate(eval_td, self.env, phase="val")
+        # add loss to metrics
+        out["loss"] = torch.stack(losses, dim=0)
+        return out
+
+    def shared_step(
+        self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
+    ):
+        orig_td = self.env.reset(batch)
+        device = orig_td.device
+        n_start = (
+            self.env.get_num_starts(orig_td)
+            if self.num_starts is None
+            else self.num_starts
+        )
+        next_td = batchify(orig_td.clone(), n_start)
+        td_stack = []
+
+        if phase == "train":
+            while not next_td["done"].all():
+
+                with torch.no_grad():
+                    td = self.policy_old.act(next_td, self.env, phase="train")
+
+                # get next state
+                next_td = self.env.step(td)["next"]
+
+                # add tensordict with action, logprobs and reward information to buffer
+                td_stack.append(td)
+            # (bs * #samples, #steps)
+            td_stack = torch.stack(td_stack, dim=1)
+            # (bs, #samples, #steps)
+            td_stack_unbs = unbatchify(td_stack, n_start)
+            # (bs * #samples)
+            rewards = self.env.get_reward(next_td, None)
+            # (bs)
+            _, best_idx = unbatchify(rewards, n_start).max(dim=1)
+            td_best = td_stack_unbs.gather(
+                1, best_idx[:, None, None].expand(-1, 1, td_stack_unbs.size(2))
+            ).squeeze(1)
+            # flatten so that every step is an experience TODO can we enhance this?
+            self.rb.extend(td_best.flatten())
+
+            # if iter mod x = 0 then update the policy (x = 1 in paper)
+            if batch_idx % self.update_timestep == 0:
+
+                out = self.update(orig_td, device)
+
+                # TODO check the details of this: if out["reward"].mean() > max_rew.mean():
+                # Copy new weights into old policy:
+                self.policy_old.load_state_dict(self.policy.state_dict())
+                # only clear the rb if we improved on the old model, otherwise the experience is still useful
+                self.rb.empty()
+
+        else:
+            out = self.policy.generate(
+                next_td, self.env, phase=phase  # , select_best=True, multisample=True
+            )
+
+        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
+        return {"loss": out.get("loss", None), **metrics}
diff --git a/rl4co/models/zoo/l2d/policy.py b/rl4co/models/zoo/l2d/policy.py
index 0cfac356..07d9fb05 100644
--- a/rl4co/models/zoo/l2d/policy.py
+++ b/rl4co/models/zoo/l2d/policy.py
@@ -203,7 +203,7 @@ def __init__(
             self.encoder, NoEncoder
         ), "Define a feature extractor for decoder rather than an encoder in stepwise PPO"
 
-    def evaluate(self, td):
+    def evaluate(self, td, return_selected=True):
         # Encoder: get encoder output and initial embeddings from initial state
         hidden, _ = self.decoder.feature_extractor(td)
         # pool the embeddings for the critic
@@ -220,10 +220,12 @@ def evaluate(self, td):
         logits, mask = self.decoder.actor(td, *hidden)
         # get logprobs and entropy over logp distribution
         logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping)
-        action_logprobs = gather_by_index(logprobs, td["action"], dim=1)
         dist_entropys = Categorical(logprobs.exp()).entropy()
-        return action_logprobs, value_pred, dist_entropys
+
+        if return_selected:
+            logprobs = gather_by_index(logprobs, td["action"], dim=1)
+
+        return logprobs, value_pred, dist_entropys
 
     def act(self, td, env, phase: str = "train"):
         logits, mask = self.decoder(td, hidden=None, num_starts=0)
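
A minimal usage sketch for the new SelfLabeling module (not part of the patch), roughly mirroring
configs/experiment/scheduling/sl.yaml. The FJSPEnv and RL4COTrainer imports, the generator
parameters, and the choice to rely on L2DPolicy4PPO's default decoder (instead of the MatNet
feature extractor wired up in sl.yaml) are assumptions for illustration only:

    # Hypothetical wiring, assuming rl4co's FJSP environment and trainer utilities;
    # adjust imports and parameters to the installed rl4co version.
    from rl4co.envs import FJSPEnv
    from rl4co.models import SelfLabeling, L2DPolicy4PPO
    from rl4co.utils.trainer import RL4COTrainer

    # Scheduling environment with episodic (non-stepwise) reward, as in sl.yaml
    env = FJSPEnv(generator_params={"num_jobs": 10, "num_machines": 5})

    # Policy with a value head; sl.yaml additionally configures a MatNet-based feature extractor
    policy = L2DPolicy4PPO(env_name=env.name, embed_dim=256, het_emb=True)

    # SelfLabeling rolls out num_starts sampled trajectories per instance with the frozen
    # policy_old, keeps the best trajectory per instance in the replay buffer, and trains the
    # current policy to imitate those actions via cross-entropy on sampled minibatches.
    model = SelfLabeling(
        env,
        policy,
        batch_size=64,
        mini_batch_size=256,
        num_starts=10,
        sl_epochs=1,
        update_timestep=1,
        metrics={"val": ["reward", "max_reward"]},
    )

    trainer = RL4COTrainer(max_epochs=10, devices=1)
    trainer.fit(model)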