diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..bc581f0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+/logs/
+/.idea/
+*.pyc
+*.csv
+logs
+log
+.vscode
diff --git a/README.md b/README.md
index 9ee3ece..a08824d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,81 @@
-# GOPT
-The code is under preparation and is coming soon.
+
+ GOPT: Generalizable Online 3D Bin Packing via Transformer-based Deep Reinforcement Learning
+
+RA-L 2024 (Accepted)
+
+
+
+
+
+If you have any questions, feel free to contact me by xiongheng@hust.edu.cn.
+
+## Introduction
+Robotic object packing has broad practical applications in the logistics and automation industry, often formulated by researchers as the online 3D Bin Packing Problem (3D-BPP). However, existing DRL-based methods primarily focus on enhancing performance in limited packing environments while neglecting the ability to generalize across multiple environments characterized by different bin dimensions. To this end, we propose GOPT, a generalizable online 3D Bin Packing approach via Transformer-based deep reinforcement learning (DRL). First, we design a Placement Generator module to yield finite subspaces as placement candidates and the representation of the bin. Second, we propose a Packing Transformer, which fuses the features of the items and bin, to identify the spatial correlation between the item to be packed and available sub-spaces within the bin. Coupling these two components enables GOPT’s ability to perform inference on bins of varying dimensions.
+
+![overview](./images/overview.png)
+
+
+## Installation
+This code has been tested on Ubuntu 20.04 with Cuda 12.1, Python3.9 and Pytorch 2.1.0.
+
+```
+git clone https://github.com/Xiong5Heng/GOPT.git
+cd GOPT
+
+conda create -n GOPT python=3.9
+conda activate GOPT
+
+# install pytorch
+conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
+
+# install other dependencies
+pip install -r requirements.txt
+```
+
+## Training
+The dataset is generated on the fly, so you can directly train the model by running the following command.
+
+```bash
+python train.py --config cfg/config.yaml --device 0
+```
+
+If you do not use the default dataset (the bin is 10x10x10), you can modify the tag `env` in `cfg/config.yaml` file to specify the bin size and the number of items.
+Note that most hyperparameters are in the `cfg/config.yaml` file, you can modify them to fit your needs.
+
+
+## Evaluation
+
+```bash
+python eval.py --config cfg/config.yaml --device 0 --ckp /path/to/checkpoint.pth
+```
+
+If you want to visualize the packing process, you can add the `--render` flag.
+```bash
+python eval.py --config cfg/config.yaml --device 0 --ckp /path/to/checkpoint.pth --render
+```
+
+## Demo
+
+
+
+
+
+## Citation
+If you find this work useful, please consider citing:
+```
+@article{xiong2024gopt,
+ title={GOPT: Generalizable Online 3D Bin Packing via Transformer-based Deep Reinforcement Learning},
+ author={Xiong, Heng and Guo, Changrong and Peng, Jian and Ding, Kai and Chen, Wenjie and Qiu, Xuchong and Bai, Long and Xu, Jianfeng},
+ journal={IEEE Robotics and Automation Letters},
+ year={2024},
+ publisher={IEEE}
+}
+```
+
+## License
+This source code is released only for academic use. Please do not use it for commercial purposes without authorization of the author.
diff --git a/arguments.py b/arguments.py
new file mode 100644
index 0000000..7d5e45c
--- /dev/null
+++ b/arguments.py
@@ -0,0 +1,59 @@
+import os
+curr_path = os.path.dirname(os.path.abspath(__file__))
+
+import argparse
+
+from omegaconf import OmegaConf
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, default="cfg/config.yaml")
+ parser.add_argument('--ckp', type=str, default=None,
+ help="Path to the model to be tested")
+ parser.add_argument('--no-cuda', action='store_true',
+ help='Cuda will be enabled by default')
+ parser.add_argument('--device', type=int, default=0,
+ help='Which GPU will be called')
+ parser.add_argument('--test-episode', type=int, default=1000,
+ help='Number of episodes for evaluation')
+ parser.add_argument('--render', action='store_true',
+ help='Render the environment while testing')
+
+ args = parser.parse_args()
+
+ try:
+ args.config = os.path.join(curr_path, args.config)
+ cfg = OmegaConf.load(args.config)
+ except FileNotFoundError:
+ print("No configuration file found")
+
+ box_small = int(max(cfg.env.container_size) / 10)
+ box_big = int(max(cfg.env.container_size) / 2)
+ # box_range = (5, 5, 5, 25, 25, 25)
+ box_range = (box_small, box_small, box_small, box_big, box_big, box_big)
+
+ if cfg.get("env.step") is not None:
+ step = cfg.env.step
+ else:
+ step = box_small
+
+ box_size_set = []
+ for i in range(box_range[0], box_range[3] + 1, step):
+ for j in range(box_range[1], box_range[4] + 1, step):
+ for k in range(box_range[2], box_range[5] + 1, step):
+ box_size_set.append((i, j, k))
+
+ cfg.env.box_small = box_small
+ cfg.env.box_big = box_big
+ cfg.env.box_size_set = box_size_set
+ cfg.cuda = not args.no_cuda
+
+ cfg = OmegaConf.merge(cfg, vars(args))
+
+ return cfg
+
+
+if __name__ == "__main__":
+ args = get_args()
+ print(args.train.reward_type)
\ No newline at end of file
diff --git a/cfg/config.yaml b/cfg/config.yaml
new file mode 100644
index 0000000..56f43c2
--- /dev/null
+++ b/cfg/config.yaml
@@ -0,0 +1,49 @@
+# configuration file for training
+
+seed: 5
+cuda_deterministic: True
+log_interval: 10 # How often to print training logs
+
+env:
+ id: OnlinePack-v1 # env name OnlinePack-v1, PCT-v0
+ scheme: EMS # the scheme of generating candidate map: heightmap, EP, FC
+ rot: True
+ box_type: random # random, cut
+ container_size: [10, 10, 10]
+ step:
+ k_placement: 80 # number of candidate placements
+
+
+train:
+ algo: PPO
+ clip_param: 0.3
+ num_processes: 2 # the number of subprogresses, if debug, set to 1
+ num_steps: 5
+ epoch: 1000
+ last_epoch: 200
+ batch_size: 128
+ step_per_epoch: 40000 # 2**15
+ repeat_per_collect: 1
+ gae_lambda: 0.96
+ reward_type: # optional: "terminal", None
+ gamma: 1 # discount factor for rewards (default: 1)
+
+opt: # optimizer
+ optimizer: Adam # optimizer: Adam, RMSprop
+ lr: 7e-5 # learning rate (RMSprop7e-4, 1e-6, Adam7e-5)
+ lr_decay: True # use a linear schedule on the learning rate
+ eps: 1e-5 # epsilon (default: 1e-5)
+ alpha: 0.99 # RMSprop alpha (default: 0.99)
+
+loss:
+ entropy: 0.001 # entropy term coefficient (default: 0.01)
+ value: 0.5 # value loss coefficient (default: 0.5)
+
+model:
+ padding_mask: False # padding mask
+ embed_dim: 128
+ heads: 1
+ num_layers: 3
+ forward_expansion: 2
+ dropout: 0
+
diff --git a/cfg/default.yaml b/cfg/default.yaml
new file mode 100644
index 0000000..9e2502f
--- /dev/null
+++ b/cfg/default.yaml
@@ -0,0 +1,46 @@
+# configuration file for training
+
+seed: 5
+cuda_deterministic: True
+log_interval: 10 # How often to print training logs
+
+env:
+ id: OnlinePack-v1 # env name
+ scheme: EMS # the scheme of generating candidate map: heightmap, EP, FC
+ rot: True
+ box_type: random
+ container_size: [10, 10, 10]
+ k_placement: 100 # number of candidate placements
+
+
+train:
+ algo: PPO
+ clip_param: 0.3
+ num_processes: 128
+ num_steps: 6
+ epoch: 500
+ batch_size: 128
+ step_per_epoch: 32768 # 2**15
+ repeat_per_collect: 1
+ gae_lambda: 0.96
+ reward_type: # optional: "terminal", None
+ gamma: 1 # discount factor for rewards (default: 1)
+
+opt: # optimizer
+ optimizer: Adam # optimizer: Adam, RMSprop
+ lr: 7e-5 # learning rate (RMSprop7e-4, 1e-6, Adam7e-5)
+ lr_decay: False # use a linear schedule on the learning rate
+ eps: 1e-5 # epsilon (default: 1e-5)
+ alpha: 0.99 # RMSprop alpha (default: 0.99)
+
+loss:
+ entropy: 0.001 # entropy term coefficient (default: 0.01)
+ value: 0.5 # value loss coefficient (default: 0.5)
+
+model:
+ embed_dim: 128
+ heads: 1
+ num_layers: 4
+ forward_expansion: 2
+ dropout: 0
+
diff --git a/envs/Packing/__init__.py b/envs/Packing/__init__.py
new file mode 100755
index 0000000..fccc8d3
--- /dev/null
+++ b/envs/Packing/__init__.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+"""
+-------------------------------------------------
+Project Name: Packing-RL
+File Name: __init__.py.py
+Author: XEH1SGH
+Create Date: 4/14/2022
+-------------------------------------------------
+"""
+# from gym.envs.registration import register
+# from packing_env import PackingGame
+#
+# register(
+# id='Pack-v0',
+# entry_point='problems.OnlinePacking:PackingGame',
+# )
+from .env import PackingEnv
+
+__version__ = "0.0.1"
+
diff --git a/envs/Packing/binCreator.py b/envs/Packing/binCreator.py
new file mode 100755
index 0000000..92cf1db
--- /dev/null
+++ b/envs/Packing/binCreator.py
@@ -0,0 +1,80 @@
+import numpy as np
+import copy
+import torch
+
+
+class BoxCreator(object):
+ def __init__(self):
+ self.box_list = [] # generated box list
+
+ def reset(self):
+ self.box_list.clear()
+
+ def generate_box_size(self, **kwargs):
+ pass
+
+ def preview(self, length):
+ """
+ :param length:
+ :return: list
+ """
+ while len(self.box_list) < length:
+ self.generate_box_size()
+ return copy.deepcopy(self.box_list[:length])
+
+ def drop_box(self):
+ assert len(self.box_list) >= 0
+ self.box_list.pop(0)
+
+
+class RandomBoxCreator(BoxCreator):
+ default_box_set = []
+ for i in range(4):
+ for j in range(4):
+ for k in range(4):
+ default_box_set.append((2 + i, 2 + j, 2 + k))
+
+ def __init__(self, box_size_set=None):
+ super().__init__()
+ self.box_set = box_size_set
+ if self.box_set is None:
+ self.box_set = RandomBoxCreator.default_box_set
+ # print(self.box_set)
+
+ def generate_box_size(self, **kwargs):
+ idx = np.random.randint(0, len(self.box_set))
+ self.box_list.append(self.box_set[idx])
+
+
+# load data
+class LoadBoxCreator(BoxCreator):
+ def __init__(self, data_name=None): # data url
+ super().__init__()
+ self.data_name = data_name
+ self.index = 0
+ self.box_index = 0
+ self.traj_nums = len(torch.load(self.data_name))
+ print("load data set successfully, data name: ", self.data_name)
+
+ def reset(self, index=None):
+ self.box_list.clear()
+ box_trajs = torch.load(self.data_name)
+ self.recorder = []
+ if index is None:
+ self.index += 1
+ else:
+ self.index = index
+ self.boxes = box_trajs[self.index]
+ self.box_index = 0
+ self.box_set = self.boxes
+ self.box_set.append([10, 10, 10])
+
+ def generate_box_size(self, **kwargs):
+ if self.box_index < len(self.box_set):
+ self.box_list.append(self.box_set[self.box_index])
+ self.recorder.append(self.box_set[self.box_index])
+ self.box_index += 1
+ else:
+ self.box_list.append((10, 10, 10))
+ self.recorder.append((10, 10, 10))
+ self.box_index += 1
diff --git a/envs/Packing/box.py b/envs/Packing/box.py
new file mode 100755
index 0000000..bf42095
--- /dev/null
+++ b/envs/Packing/box.py
@@ -0,0 +1,18 @@
+
+class Box(object):
+ def __init__(self, length, width, height, x, y, z):
+ # dimension(x, y, z) + position(lx, ly, lz)
+ self.size_x = length
+ self.size_y = width
+ self.size_z = height
+ self.pos_x = x
+ self.pos_y = y
+ self.pos_z = z
+
+ def standardize(self):
+ """
+
+ Returns:
+ tuple(size + position)
+ """
+ return tuple([self.size_x, self.size_y, self.size_z, self.pos_x, self.pos_y, self.pos_z])
diff --git a/envs/Packing/container.py b/envs/Packing/container.py
new file mode 100755
index 0000000..75739b8
--- /dev/null
+++ b/envs/Packing/container.py
@@ -0,0 +1,698 @@
+import copy
+import itertools
+from functools import reduce
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from scipy.spatial import ConvexHull
+from matplotlib.path import Path
+
+from .ems import compute_ems
+from .utils import *
+from .box import Box
+
+
+class Container(object):
+ def __init__(self, length=10, width=10, height=10, rotation=True):
+ self.dimension = np.array([length, width, height])
+ self.heightmap = np.zeros(shape=(length, width), dtype=np.int32)
+ self.can_rotate = rotation
+ # packed box list
+ self.boxes = []
+ # record rotation information
+ self.rot_flags = []
+ self.height = height
+ self.candidates = [[0, 0, 0]]
+
+ def print_heightmap(self):
+ print("container heightmap: \n", self.heightmap)
+
+ def get_heightmap(self):
+ """
+ get the heightmap for the ideal situation
+ Returns:
+
+ """
+ plain = np.zeros(shape=self.dimension[:2], dtype=np.int32)
+ for box in self.boxes:
+ plain = self.update_heightmap(plain, box)
+ return plain
+
+ def update_heightmap_vision(self, vision):
+ """
+ TODO
+ Args:
+ vision:
+
+ Returns:
+
+ """
+ self.heightmap = vision
+
+ @property
+ def visible_area(self):
+ """
+ refer to leetcode 892: calculate surface area of 3D shape
+ Returns:
+
+ """
+ state_copy = copy.deepcopy(self.heightmap)
+
+ # cnt: the number of cube with dimension 1x1x1
+ # ov: the number of intersections
+ # num: The number of squares that are not zero, is used to calculate the area of the bottom surface
+ cnt, ov, num = 0, 0, 0
+ # The area of the cube in contact with the bin's perimeter
+ area_x_min, area_x_max, area_y_min, area_y_max = 0, 0, 0, 0
+ m, n = len(state_copy), len(state_copy[0]) # length, width
+
+ for i in range(0, m):
+ for j in range(0, n):
+ cnt += state_copy[i, j]
+
+ # up and down
+ if state_copy[i, j] > 1:
+ ov += state_copy[i, j] - 1
+ # Forward and backward direction (x direction)
+ if i > 0: # backward
+ ov += min(state_copy[i, j], state_copy[i - 1, j])
+ # Left and right side (y direction)
+ if j > 0: # backward
+ ov += min(state_copy[i, j], state_copy[i, j - 1])
+
+ # calculate the area of the bottom surface
+ if state_copy[i, j] > 0:
+ num += 1
+ # Calculate the area of contact with the side of the bin
+ if i == 0:
+ area_x_min += state_copy[i, j]
+ if i == m - 1:
+ area_x_max += state_copy[i, j]
+ if j == 0:
+ area_y_min += state_copy[i, j]
+ if j == n - 1:
+ area_y_max += state_copy[i, j]
+
+ # A cube with 6 faces, two adjacent cubes with a total of two faces close to each other, need to subtract
+ return cnt * 6 - ov * 2 - (num + area_x_min + area_x_max + area_y_min + area_y_max)
+
+ @property
+ def compactness(self):
+ """
+ TODO, need test
+ Returns:
+
+ """
+ # compute max occupied length along X and Y axis, and maximum height
+ height_max = np.max(self.heightmap)
+ nonzero_index = list(np.nonzero(self.heightmap))
+ if len(nonzero_index[0]) == 0 and len(nonzero_index[1]) == 0:
+ # no non_zero value
+ compactness_ratio = 0
+ return compactness_ratio
+ nonzero_x_max = np.max(nonzero_index[0]) + 1
+ nonzero_y_max = np.max(nonzero_index[1]) + 1
+ nonzero_x_min = np.min(nonzero_index[0])
+ nonzero_y_min = np.min(nonzero_index[1])
+
+ vo = reduce(lambda x, y: x + y, [box.size_x * box.size_y * box.size_z for box in self.boxes], 0.0)
+ compactness_ratio = vo / ((nonzero_x_max - nonzero_x_min) * (nonzero_y_max - nonzero_y_min) * height_max)
+
+ return compactness_ratio
+
+ @property
+ def height_var(self):
+ """ Variance of height
+
+ Returns:
+
+ """
+ height_var = eval(format(np.var(self.heightmap), '.2f'))
+ return height_var
+
+ @staticmethod
+ def update_heightmap(plain, box):
+ """
+ update heightmap
+ Args:
+ plain:
+ box:
+
+ Returns:
+
+ """
+ plain = copy.deepcopy(plain)
+ le = box.pos_x
+ ri = box.pos_x + box.size_x
+ up = box.pos_y
+ do = box.pos_y + box.size_y
+ max_h = np.max(plain[le:ri, up:do])
+ max_h = max(max_h, box.pos_z + box.size_z)
+ plain[le:ri, up:do] = max_h
+ return plain
+
+ def get_box_list(self):
+ vec = list()
+ for box in self.boxes:
+ vec += box.standardize()
+ return vec
+
+ def get_plain(self):
+ return copy.deepcopy(self.heightmap)
+
+ def get_action_space(self):
+ return self.dimension[0] * self.dimension[1]
+
+ def get_action_mask(self, next_box, scheme="heightmap"):
+ action_mask = np.zeros(shape=(self.dimension[0], self.dimension[1]), dtype=np.int32)
+
+ if scheme == "heightmap":
+ candidates_xy, extra_corner_xy = self.candidate_from_heightmap(next_box, self.can_rotate)
+
+ for xy in candidates_xy:
+ if self.check_box(next_box, xy) > -1:
+ action_mask[xy[0], xy[1]] = 1
+ for xy in extra_corner_xy[:3]:
+ if self.check_box(next_box, xy) > -1:
+ action_mask[xy[0], xy[1]] = 1
+
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ action_mask_rot = np.zeros_like(action_mask)
+
+ for xy in candidates_xy:
+ if self.check_box(rotated_box, xy) > -1:
+ action_mask_rot[xy[0], xy[1]] = 1
+ for xy in extra_corner_xy[-3:]:
+ if self.check_box(rotated_box, xy) > -1:
+ action_mask_rot[xy[0], xy[1]] = 1
+
+ action_mask = np.hstack((action_mask.reshape((-1,)), action_mask_rot.reshape((-1,))))
+
+ elif scheme == "EP":
+ candidates_xy, extra_corner_xy = self.candidate_from_EP(next_box, self.can_rotate)
+ # extra_corner_xy = []
+ for xy in candidates_xy:
+ if self.check_box(next_box, xy) > -1:
+ action_mask[xy[0], xy[1]] = 1
+ for xy in extra_corner_xy[:3]:
+ if self.check_box(next_box, xy) > -1:
+ action_mask[xy[0], xy[1]] = 1
+
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ action_mask_rot = np.zeros_like(action_mask)
+
+ for xy in candidates_xy:
+ if self.check_box(rotated_box, xy) > -1:
+ action_mask_rot[xy[0], xy[1]] = 1
+ for xy in extra_corner_xy[-3:]:
+ if self.check_box(rotated_box, xy) > -1:
+ action_mask_rot[xy[0], xy[1]] = 1
+
+ action_mask = np.hstack((action_mask.reshape((-1,)), action_mask_rot.reshape((-1,))))
+
+ elif scheme == "FC":
+ x_list = list(range(self.dimension[0]))
+ y_list = list(range(self.dimension[1]))
+ candidates_xy = list(itertools.product(x_list, y_list))
+
+ for xy in candidates_xy:
+ if self.check_box(next_box, xy) > -1:
+ action_mask[xy[0], xy[1]] = 1
+
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ action_mask_rot = np.zeros_like(action_mask)
+
+ for xy in candidates_xy:
+ if self.check_box(rotated_box, xy) > -1:
+ action_mask_rot[xy[0], xy[1]] = 1
+
+ action_mask = np.hstack((action_mask.reshape((-1,)), action_mask_rot.reshape((-1,))))
+
+ # assert False, 'No FC implementation'
+ else:
+ assert False, 'Wrong candidate generation scheme'
+
+ # if all actions are invalid, set all mask is 1 and perform any action to end this episode
+ if action_mask.sum() == 0:
+ action_mask[:] = 1
+
+ return action_mask.reshape(-1).tolist()
+
+ def check_box(self, box_size, pos_xy, benchmark=False):
+ """
+ check
+ 1. whether cross the border
+ 2. check stability
+ Args:
+ box_size:
+ pos_xy:
+
+ Returns:
+
+ """
+ if pos_xy[0] + box_size[0] > self.dimension[0] or pos_xy[1] + box_size[1] > self.dimension[1]:
+ return -1
+
+ pos_z = np.max(self.heightmap[pos_xy[0]:pos_xy[0] + box_size[0], pos_xy[1]:pos_xy[1] + box_size[1]])
+
+ # whether cross the broder
+ if pos_z + box_size[2] > self.dimension[2]:
+ return -1
+
+ # check stability
+ if benchmark:
+ # zhao AAAI2021 paper
+ rec = self.heightmap[pos_xy[0]:pos_xy[0] + box_size[0], pos_xy[1]:pos_xy[1] + box_size[1]]
+ r00 = rec[0, 0]
+ r10 = rec[box_size[0] - 1, 0]
+ r01 = rec[0, box_size[1] - 1]
+ r11 = rec[box_size[0] - 1, box_size[1] - 1]
+ rm = max(r00, r10, r01, r11)
+ sc = int(r00 == rm) + int(r10 == rm) + int(r01 == rm) + int(r11 == rm)
+ # at least 3 support point
+ if sc < 3:
+ return -1
+ # check area and corner
+ max_area = np.sum(rec == pos_z)
+ area = box_size[0] * box_size[1]
+ #
+ if max_area / area > 0.95:
+ return pos_z
+ if rm == pos_z and sc == 3 and max_area/area > 0.85:
+ return pos_z
+ if rm == pos_z and sc == 4 and max_area/area > 0.50:
+ return pos_z
+ else:
+ if self.is_stable(box_size, [pos_xy[0], pos_xy[1], pos_z]):
+ return pos_z
+
+ return -1
+
+ def check_box_ems(self, box_size, ems, benchmark=False):
+ """
+ check
+ 1. whether cross the border
+ 2. check stability
+ Args:
+ box_size:
+ pos_xy:
+
+ Returns:
+
+ """
+ if ems[3] - ems[0] < box_size[0] or ems[4] - ems[1] < box_size[1] or ems[5] - ems[2] < box_size[2]:
+ return -1
+
+ if ems[0] + box_size[0] > self.dimension[0] or ems[1] + box_size[1] > self.dimension[1]:
+ return -1
+
+ pos_z = np.max(self.heightmap[ems[0]:ems[0] + box_size[0], ems[1]:ems[1] + box_size[1]])
+
+ # whether cross the broder
+ if pos_z + box_size[2] > self.dimension[2]:
+ return -1
+
+ # check stability
+ if self.is_stable(box_size, [ems[0], ems[1], pos_z]):
+ return pos_z
+
+ return -1
+
+ def is_stable(self, dimension, position) -> bool:
+ """
+ check stability for 3D packing
+ Args:
+ dimension:
+ position:
+
+ Returns:
+
+ """
+ def on_segment(P1, P2, Q):
+ if ((Q[0] - P1[0]) * (P2[1] - P1[1]) == (P2[0] - P1[0]) * (Q[1] - P1[1]) and
+ min(P1[0], P2[0]) <= Q[0] <= max(P1[0], P2[0]) and
+ min(P1[1], P2[1]) <= Q[1] <= max(P1[1], P2[1])):
+ return True
+ else:
+ return False
+
+ # item on the ground of the bin
+ if position[2] == 0:
+ return True
+
+ # calculate barycentric coordinates, -1 means coordinate indices start at zero
+ x_1 = position[0]
+ x_2 = x_1 + dimension[0] - 1
+ y_1 = position[1]
+ y_2 = y_1 + dimension[1] - 1
+ z = position[2] - 1
+ obj_center = ((x_1 + x_2) / 2, (y_1 + y_2) / 2)
+
+ # valid points right under this object
+ points = []
+ for x in range(x_1, x_2 + 1):
+ for y in range(y_1, y_2 + 1):
+ if self.heightmap[x][y] == (z + 1):
+ points.append([x, y])
+
+ # the support area is more than half of the bottom surface of the item
+ if len(points) > dimension[0] * dimension[1] * 0.5:
+ return True
+
+ if len(points) == 0 or len(points) == 1:
+ return False
+ elif len(points) == 2: # whether the center lies on the line of the two points
+ return on_segment(points[0], points[1], obj_center)
+ else:
+ # calculate the convex hull of the points
+ points = np.array(points)
+ try:
+ convex_hull = ConvexHull(points)
+ except:
+ # error means co-lines
+ start_p = min(points, key=lambda p: [p[0], p[1]])
+ end_p = max(points, key=lambda p: [p[0], p[1]])
+ return on_segment(start_p, end_p, obj_center)
+
+ hull_path = Path(points[convex_hull.vertices])
+
+ return hull_path.contains_point(obj_center)
+
+ def get_volume_ratio(self):
+ vo = reduce(lambda x, y: x + y, [box.size_x * box.size_y * box.size_z for box in self.boxes], 0.0)
+ mx = self.dimension[0] * self.dimension[1] * self.dimension[2]
+ ratio = vo / mx
+ assert ratio <= 1.0
+ return ratio
+
+ # 1d index -> 2d plain coordinate
+ def idx_to_position(self, idx):
+ """
+ TODO
+ Args:
+ idx:
+
+ Returns:
+
+ """
+ lx = idx // self.dimension[1]
+ ly = idx % self.dimension[1]
+ return lx, ly
+
+ def position_to_index(self, position):
+ assert len(position) == 2
+ assert position[0] >= 0 and position[1] >= 0
+ assert position[0] < self.dimension[0] and position[1] < self.dimension[1]
+ return position[0] * self.dimension[1] + position[1]
+
+ def place_box(self, box_size, pos, rot_flag):
+ """ place box in the position (index), then update heightmap
+ :param box_size:
+ :param idx:
+ :param rot_flag:
+ :return:
+ """
+ if not rot_flag:
+ size_x = box_size[0]
+ size_y = box_size[1]
+ else:
+ size_x = box_size[1]
+ size_y = box_size[0]
+ size_z = box_size[2]
+ plain = self.heightmap
+ new_h = self.check_box([size_x, size_y, size_z], [pos[0], pos[1]])
+ if new_h != -1:
+ self.boxes.append(Box(size_x, size_y, size_z, pos[0], pos[1], pos[2])) # record rotated box
+ self.rot_flags.append(rot_flag)
+ self.heightmap = self.update_heightmap(plain, self.boxes[-1])
+ self.height = max(self.height, pos[2] + size_z)
+ return True
+ return False
+
+ def candidate_from_heightmap(self, next_box, max_n) -> list:
+ """
+ get the x and y coordinates of candidates
+ Args:
+ next_box:
+ can_rotate:
+
+ Returns:
+
+ """
+ heightmap = copy.deepcopy(self.heightmap)
+
+ corner_list = []
+ # hm_diff: height differences of neighbor columns, padding 0 in the front
+ # x coordinate
+ # heightmap: [r0, r1, r2, r3, r4, r5, ..., rn]
+ # insert: [r0, r0, r1, r2, r3, r4, r5, ..., rn]
+ hm_diff_x = np.insert(heightmap, 0, heightmap[0, :], axis=0)
+ # delete: [r0, r0, r1, r2, r3, r4, ..., rn-1]
+ hm_diff_x = np.delete(hm_diff_x, len(hm_diff_x) - 1, axis=0)
+ # hm_diff_x: [0, r1-r0, r2-r1, r3-r2, r4-r3, r5-r4, rn-r(n-1)]
+ hm_diff_x = heightmap - hm_diff_x
+
+ # y coordinate
+ hm_diff_y = np.insert(heightmap, 0, heightmap[:, 0], axis=1)
+ hm_diff_y = np.delete(hm_diff_y, len(hm_diff_y.T) - 1, axis=1)
+ # hm_diff_y: [0, c1-c0, c2-c1, c3-c2, c4-c3, c5-c4, cn-c(n-1)]
+ hm_diff_y = heightmap - hm_diff_y
+
+ # get the xy coordinates of all left-deep-bottom corners
+ corner_x_list = np.array(np.nonzero(hm_diff_x)).T.tolist()
+ corner_y_list = np.array(np.nonzero(hm_diff_y)).T.tolist()
+
+ corner_xy_list = []
+ corner_xy_list.append([0, 0])
+
+ for xy in corner_x_list:
+ x, y = xy
+ if y != 0 and [x, y - 1] in corner_x_list:
+ # if heightmap[x, y] == heightmap[x, y - 1] and hm_diff_x[x, y] == hm_diff_x[x, y - 1]:
+ if heightmap[x, y] == heightmap[x, y - 1]:
+ continue
+ corner_xy_list.append(xy)
+ for xy in corner_y_list:
+ x, y = xy
+ if x != 0 and [x - 1, y] in corner_y_list:
+ # if heightmap[x, y] == heightmap[x - 1, y] and hm_diff_x[x, y] == hm_diff_x[x - 1, y]:
+ if heightmap[x, y] == heightmap[x - 1, y]:
+ continue
+ if xy not in corner_xy_list:
+ corner_xy_list.append(xy)
+
+ candidate_x, candidate_y = zip(*corner_xy_list)
+ # remove duplicate elements
+ candidate_x = list(set(candidate_x))
+ candidate_y = list(set(candidate_y))
+
+ # get corner_list
+ corner_list = list(itertools.product(candidate_x, candidate_y))
+ candidates = []
+
+ for xy in corner_list:
+ z = self.check_box(next_box, xy)
+ if z > -1:
+ # candidates.append([xy[0], xy[1], z, 0])
+ candidates.append([xy[0], xy[1], z, xy[0] + next_box[0], xy[1] + next_box[1], z + next_box[2]])
+
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ for xy in corner_list:
+ z = self.check_box(rotated_box, xy)
+ if z > -1:
+ # candidates.append([xy[0], xy[1], z, 1])
+ candidates.append([xy[0], xy[1], z, xy[0] + rotated_box[0], xy[1] + rotated_box[1], z + rotated_box[2]])
+
+ # sort by z, y coordinate, then x
+ candidates.sort(key=lambda x: [x[2], x[1], x[0]])
+
+ if len(candidates) > max_n:
+ candidates = candidates[:max_n]
+ self.candidates = candidates
+ return np.array(candidates)
+
+ def candidate_from_EP(self, next_box, max_n) -> list:
+ """
+ calculate extreme points from items extracted from current heightmap
+ Args:
+ new_item:
+
+ Returns:
+
+ """
+ heightmap = copy.deepcopy(self.heightmap)
+ items_in = extract_items_from_heightmap(heightmap)
+ new_eps = []
+ new_eps.append([0, 0, 0])
+
+ for k in range(len(items_in)):
+ items_in_copy = copy.deepcopy(items_in)
+ item_new = items_in_copy[k]
+ new_dim = item_new[:3]
+ new_pos = item_new[-3:]
+
+ items_in_copy.pop(k)
+ item_fitted = items_in_copy
+
+ # add xoy, xoz, yoz planes for easy projection
+ item_fitted.append([self.dimension[0], self.dimension[1], 0, 0, 0, 0])
+ item_fitted.append([self.dimension[0], 0, self.dimension[2], 0, 0, 0])
+ item_fitted.append([0, self.dimension[1], self.dimension[2], 0, 0, 0])
+
+ max_bounds = [-1, -1, -1, -1, -1, -1]
+
+ for i in range(len(item_fitted)):
+ fitted_dim = item_fitted[i][:3]
+ fitted_pos = item_fitted[i][-3:]
+ project_x = fitted_dim[0] + fitted_pos[0]
+ project_y = fitted_dim[1] + fitted_pos[1]
+ project_z = fitted_dim[2] + fitted_pos[2]
+
+ # Xy - new_eps[0]
+ if can_take_projection(item_new, item_fitted[i], 0, 1) and project_y > max_bounds[Projection.Xy]:
+ new_eps.append([new_pos[0] + new_dim[0], project_y, new_pos[2]])
+ max_bounds[Projection.Xy] = project_y
+
+ # Xz - new_eps[1]
+ if can_take_projection(item_new, item_fitted[i], 0, 2) and project_z > max_bounds[Projection.Xz]:
+ new_eps.append([new_pos[0] + new_dim[0], new_pos[1], project_z])
+ max_bounds[Projection.Xz] = project_z
+
+ # Yx - new_eps[2]
+ if can_take_projection(item_new, item_fitted[i], 1, 0) and project_x > max_bounds[Projection.Yx]:
+ new_eps.append([project_x, new_pos[1] + new_dim[1], new_pos[2]])
+ max_bounds[Projection.Yx] = project_x
+
+ # Yz - new_eps[3]
+ if can_take_projection(item_new, item_fitted[i], 1, 2) and project_z > max_bounds[Projection.Yz]:
+ new_eps.append([new_pos[0], new_pos[1] + new_dim[1], project_z])
+ max_bounds[Projection.Yz] = project_z
+
+ # Zx - new_eps[4]
+ if can_take_projection(item_new, item_fitted[i], 2, 0) and project_x > max_bounds[Projection.Zx]:
+ new_eps.append([project_x, new_pos[1], new_pos[2] + new_dim[2]])
+ max_bounds[Projection.Zx] = project_x
+
+ # Zy - new_eps[5]
+ if can_take_projection(item_new, item_fitted[i], 2, 1) and project_y > max_bounds[Projection.Zy]:
+ new_eps.append([new_pos[0], project_y, new_pos[2] + new_dim[2]])
+ max_bounds[Projection.Zy] = project_y
+
+ new_eps = [ep for ep in new_eps if not (ep[0] == self.dimension[0] or
+ ep[1] == self.dimension[1] or
+ ep[2] == self.dimension[2])]
+
+ # only need x, y
+ new_eps = np.array(new_eps, dtype=np.int32)
+
+ # remove duplicates
+ new_eps = np.unique(new_eps, axis=0)
+ candidates = new_eps.tolist()
+ candidates.sort(key=lambda x: [x[2], x[1], x[0]])
+ mask = np.zeros((2, max_n), dtype=np.int8)
+
+ if len(candidates) > max_n:
+ candidates = candidates[:max_n]
+
+ for id, ep in enumerate(candidates):
+ z = self.check_box(next_box, ep)
+ if z > -1 and z == ep[2]:
+ mask[0, id] = 1
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ for id, ep in enumerate(candidates):
+ z = self.check_box(rotated_box, ep)
+ if z > -1 and z == ep[2]:
+ mask[1, id] = 1
+
+ self.candidates = candidates
+ return np.array(candidates), mask
+
+ def candidate_from_EMS(self,
+ next_box,
+ max_n
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """ calculate Empty Maximum Space from items extracted from current heightmap
+
+ Args:
+ next_box (_type_): _description_
+ max_n (_type_): _description_
+
+ Returns:
+ list: _description_
+ """
+ heightmap = copy.deepcopy(self.heightmap)
+ # left-bottom & right-top pos [bx, by, bz, tx, ty, tz],
+ # dimension: [tx-bx, ty-by, tz-bz],
+ all_ems = compute_ems(heightmap, container_h=self.dimension[2])
+
+ candidates = all_ems
+ mask = np.zeros((2, max_n), dtype=np.int8)
+
+ # sort by z, y coordinate, then x
+ candidates.sort(key=lambda x: [x[2], x[1], x[0]])
+
+ if len(candidates) > max_n:
+ candidates = candidates[:max_n]
+
+ for id, ems in enumerate(candidates):
+ if self.check_box_ems(next_box, ems) > -1:
+ mask[0, id] = 1
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ for id, ems in enumerate(candidates):
+ if self.check_box_ems(rotated_box, ems) > -1:
+ mask[1, id] = 1
+
+ self.candidates = candidates
+ return np.array(candidates), mask
+
+ def candidate_from_FC(self, next_box) -> list:
+ """
+ calculate extreme points from items extracted from current heightmap
+ Args:
+ new_item:
+
+ Returns:
+
+ """
+ candidates = []
+
+ for x in range(self.dimension[0]):
+ for y in range(self.dimension[1]):
+ candidates.append([x, y, self.heightmap[x, y]])
+
+ mask = np.zeros((2, self.dimension[0]*self.dimension[1]), dtype=np.int8)
+
+ for id, xyz in enumerate(candidates):
+ z = self.check_box(next_box, xyz)
+ if z > -1 and z == xyz[2]:
+ mask[0, id] = 1
+ if self.can_rotate:
+ rotated_box = [next_box[1], next_box[0], next_box[2]]
+ for id, xyz in enumerate(candidates):
+ z = self.check_box(rotated_box, xyz)
+ if z > -1 and z == xyz[2]:
+ mask[1, id] = 1
+
+ self.candidates = candidates
+ return np.array(candidates), mask
+
+
+if __name__ == '__main__':
+ container = Container(3, 4, 10)
+ container.heightmap = np.array([[5, 1, 4, 4],
+ [1, 5, 4, 1],
+ [4, 4, 4, 1]])
+ # container.print_heightmap()
+ # next = [3, 2, 2]
+ # mask = container.get_action_mask(next, True)
+
+ # print(mask.reshape((-1, 10, 8)))
+ print(container.place_box([3, 3, 3], [0, 0, 5], 0))
+
+ # print(container.candidate_from_EMS([2, 2, 2], 10))
diff --git a/envs/Packing/cutCreator.py b/envs/Packing/cutCreator.py
new file mode 100755
index 0000000..d41ff81
--- /dev/null
+++ b/envs/Packing/cutCreator.py
@@ -0,0 +1,152 @@
+import numpy as np
+import copy
+import random
+import torch
+from .binCreator import BoxCreator
+
+class MetaBox():
+ def __init__(self, size_x, size_y, size_z, pos_x, pos_y, pos_z):
+ self.size_x = size_x
+ self.size_y = size_y
+ self.size_z = size_z
+ self.pos_x = pos_x
+ self.pos_y = pos_y
+ self.pos_z = pos_z
+
+ def split(self, divide_flag, pos):
+ if divide_flag == 0:
+ box1 = MetaBox(pos, self.size_y, self.size_z, self.pos_x, self.pos_y, self.pos_z)
+ box2 = MetaBox(self.size_x - pos, self.size_y, self.size_z, self.pos_x + pos, self.pos_y, self.pos_z)
+ elif divide_flag == 1:
+ box1 = MetaBox(self.size_x, pos, self.size_z, self.pos_x, self.pos_y, self.pos_z)
+ box2 = MetaBox(self.size_x, self.size_y - pos, self.size_z, self.pos_x, self.pos_y + pos, self.pos_z)
+ elif divide_flag == 2:
+ box1 = MetaBox(self.size_x, self.size_y, pos, self.pos_x, self.pos_y, self.pos_z)
+ box2 = MetaBox(self.size_x, self.size_y, self.size_z - pos, self.pos_x, self.pos_y, self.pos_z + pos)
+ return box1, box2
+
+ def __str__(self):
+ return '(%d, %d, %d, %d, %d, %d)' % (self.size_x, self.size_y, self.size_z, self.pos_x, self.pos_y, self.pos_z)
+
+
+class CuttingBoxCreator(BoxCreator):
+ def __init__(self, bin_size, box_range, rotation=False):
+ super().__init__()
+ self.box_list = []
+ self.bin_size = bin_size
+ self.box_range = box_range
+ self.rotation = rotation
+
+ self.plain = np.zeros(shape=(self.bin_size[0], self.bin_size[1]), dtype=np.int32)
+ self.meta_list = [MetaBox(*self.bin_size, 0, 0, 0)]
+ self.candidates = []
+ self._cut_box(*self.box_range)
+ self._add_candidate()
+
+ def reset(self):
+ self.box_list.clear()
+ self.plain = np.zeros(shape=(self.bin_size[0], self.bin_size[1]), dtype=np.int32)
+ self.meta_list = [MetaBox(*self.bin_size, 0, 0, 0)]
+ self.candidates = []
+ self._cut_box(*self.box_range)
+ self._add_candidate()
+
+ def _check_box(self, box, low_x, low_y, low_z, high_x, high_y, high_z):
+ x_flag = box.size_x < low_x or box.size_x > high_x
+ y_flag = box.size_y < low_y or box.size_y > high_y
+ z_flag = box.size_z < low_z or box.size_z > high_z
+ return x_flag * 1 + y_flag * 2 + z_flag * 4
+
+ def _choose_pos(self, box, check, low_x, low_y, low_z, high_x, high_y, high_z):
+ df_list = []
+ if 1 & check:
+ df_list.append(0)
+ if 2 & check:
+ df_list.append(1)
+ if 4 & check:
+ df_list.append(2)
+ df = random.choice(df_list)
+ if df == 0:
+ pos_range = (low_x, box.size_x - low_x)
+ if df == 1:
+ pos_range = (low_y, box.size_y - low_y)
+ if df == 2:
+ pos_range = (low_z, box.size_z - low_z)
+ assert pos_range[0] <= pos_range[1]
+ pos = random.randint(pos_range[0], pos_range[1])
+ return df, pos
+
+ def _cut_box(self, low_x, low_y, low_z, high_x, high_y, high_z):
+ continue_flag = True
+ new_list = []
+ while continue_flag:
+ continue_flag = False
+ for box in self.meta_list:
+ check = self._check_box(box, low_x, low_y, low_z, high_x, high_y, high_z)
+ if check == 0:
+ new_list.append(box)
+ else:
+ df, pos = self._choose_pos(box, check, low_x, low_y, low_z, high_x, high_y, high_z)
+ box1, box2 = box.split(df, pos)
+ new_list.append(box1)
+ new_list.append(box2)
+ continue_flag = True
+ self.meta_list = copy.deepcopy(new_list)
+ new_list.clear()
+ # print('total box num: ', len(self.meta_list))
+
+ def _add_candidate(self):
+ new_list = []
+ for i in range(len(self.meta_list)):
+ mb = self.meta_list[i]
+ check = (self.plain[mb.pos_x:mb.pos_x + mb.size_x, mb.pos_y:mb.pos_y + mb.size_y] == mb.pos_z).sum() - mb.size_x * mb.size_y
+ if check == 0:
+ self.candidates.append(mb)
+ else:
+ new_list.append(mb)
+ self.meta_list = new_list
+
+ def _update(self, box):
+ self.plain[box.pos_x:box.pos_x + box.size_x, box.pos_y:box.pos_y + box.size_y] += box.size_z
+
+ def generate_box_size(self, **kwargs):
+ if len(self.candidates) == 0:
+ self.box_list.append(self.bin_size)
+ return
+ idx = random.randint(0, len(self.candidates) - 1)
+ box = self.candidates.pop(idx)
+ if not self.rotation:
+ self.box_list.append((box.size_x, box.size_y, box.size_z))
+ else:
+ rd = np.random.rand()
+ # randomly rotate boxes
+ if rd < 0.5:
+ self.box_list.append((box.size_x, box.size_y, box.size_z))
+ else:
+ self.box_list.append((box.size_y, box.size_x, box.size_z))
+ self._update(box)
+ self._add_candidate()
+
+class LoadBoxCreator(BoxCreator):
+ def __init__(self, data_name = None):
+ super().__init__()
+ self.box_trajs = torch.load(data_name)
+ print("load data set successfully!")
+ self.index = 0
+ self.box_index = 0
+ self.traj_nums = len(self.box_trajs)
+
+ def reset(self):
+ self.box_list.clear()
+ self.boxes = self.box_trajs[self.index]
+ self.recorder = []
+ self.index += 1
+ self.box_index = 0
+ self.box_set = self.boxes
+ self.box_set.append([10, 10, 10])
+
+ def generate_box_size(self, **kwargs):
+ self.box_list.append(self.box_set[self.box_index])
+ self.recorder.append(self.box_set[self.box_index])
+ self.box_index += 1
+
diff --git a/envs/Packing/ems.py b/envs/Packing/ems.py
new file mode 100644
index 0000000..f81e015
--- /dev/null
+++ b/envs/Packing/ems.py
@@ -0,0 +1,240 @@
+
+import copy
+import itertools
+import time
+
+import numpy as np
+
+
+def compute_corners(heightmap: np.ndarray):
+ # NOTE find corners by heightmap
+
+ hm_shape = heightmap.shape
+ extend_hm = np.ones((hm_shape[0]+2, hm_shape[1]+2)) * 10000
+ extend_hm[1:-1, 1:-1] = heightmap
+
+ x_diff_hm_1 = extend_hm[:-1] - extend_hm[1:]
+ x_diff_hm_1 = x_diff_hm_1[:-1, 1:-1]
+
+ x_diff_hm_2 = extend_hm[1:] - extend_hm[:-1]
+ x_diff_hm_2 = x_diff_hm_2[1:, 1:-1]
+
+ y_diff_hm_1 = extend_hm[:, :-1] - extend_hm[:, 1:]
+ y_diff_hm_1 = y_diff_hm_1[1:-1, :-1]
+
+ y_diff_hm_2 = extend_hm[:, 1:] - extend_hm[:, :-1]
+ y_diff_hm_2 = y_diff_hm_2[1:-1, 1:]
+
+ x_diff_hms = [x_diff_hm_1 != 0, x_diff_hm_2 != 0]
+ y_diff_hms = [y_diff_hm_1 != 0, y_diff_hm_2 != 0]
+
+ corner_hm = np.zeros_like(heightmap)
+ for xhm in x_diff_hms:
+ for yhm in y_diff_hms:
+ corner_hm += xhm * yhm
+
+ left_bottom_hm = (x_diff_hm_1 != 0) * (y_diff_hm_1 != 0)
+
+ left_bottom_corners = np.where(left_bottom_hm > 0)
+ left_bottom_corners = np.array(left_bottom_corners).transpose()
+
+ corners = np.where(corner_hm > 0)
+ corners = np.array(corners).transpose()
+
+ # x_borders = list(np.where(x_diff_hm_1.sum(axis=1))[0])
+ # y_borders = list(np.where(y_diff_hm_1.sum(axis=0))[0])
+ x_borders = list(np.unique(np.where(x_diff_hm_1 != 0)[0]))
+ y_borders = list(np.unique(np.where(y_diff_hm_1 != 0)[0]))
+
+ x_borders.append(hm_shape[0])
+ y_borders.append(hm_shape[1])
+
+ return corners, left_bottom_corners, x_borders, y_borders
+
+
+def compute_stair_corners(heightmap, corners):
+
+ corners, _, _, _ = compute_corners(heightmap)
+
+ stair_hm = np.zeros_like(heightmap)
+ corner_heights = heightmap[corners[:,0], corners[:,1]]
+ sort_ids = np.argsort(corner_heights)
+ sort_corners = corners[sort_ids]
+
+ for c in sort_corners:
+ cx, cy = c
+ h = heightmap[cx, cy]
+ stair_hm[:cx+1, :cy+1] = h
+
+ _, slb_corner, _, _ = compute_corners(stair_hm)
+ return slb_corner
+
+
+def compute_empty_space(
+ container_h,
+ corners,
+ x_borders,
+ y_borders,
+ heightmap,
+ empty_space_list,
+ x_side='left-right',
+ y_side='left-right',
+ min_ems_width=0,
+ container_id=0
+ ):
+ # NOTE find ems from corners
+ # EMS: [ [bx,by,bz], [tx,ty,tz], [i,i,i] ]
+ # 1. left-bottom pos [bx, by, bz]
+ # 2. right-top pos: [tx, ty, tz]
+ # 3. container_id: [i, i, i]
+
+ def check_valid_height_layer(height_layer):
+ return (height_layer <= 0).all()
+
+ for corner in corners:
+ x,y = corner
+ # h = int(heightmap[x, y])
+ h = heightmap[x, y]
+ if h == container_h: continue
+
+ h_layer = heightmap - h
+
+ for axes in itertools.permutations(range(2), 2):
+ x_small = x
+ x_large = x+1
+
+ y_small = y
+ y_large = y+1
+
+ for axis in axes:
+ if axis == 0:
+ if 'left' in x_side:
+ for xb in x_borders:
+ if x_small > xb:
+ # if (h_layer[xb:x, y_small:y_large] <= 0).all():
+ if check_valid_height_layer(h_layer[xb:x, y_small:y_large]):
+ x_small = xb
+ else: break
+
+ if 'right' in x_side:
+ for xb in x_borders[::-1]:
+ if x_large < xb:
+ if check_valid_height_layer(h_layer[x:xb, y_small:y_large]):
+ # if (h_layer[x:xb, y_small:y_large] <= 0).all():
+ x_large = xb
+ else: break
+
+ elif axis == 1:
+ if 'left' in y_side:
+ for yb in y_borders:
+ if y_small > yb:
+ if check_valid_height_layer(h_layer[ x_small:x_large, yb:y]):
+ # if (h_layer[ x_small:x_large, yb:y] <= 0).all():
+ y_small = yb
+ else: break
+
+ if 'right' in y_side:
+ for yb in y_borders[::-1]:
+ if y_large < yb:
+ if check_valid_height_layer(h_layer[ x_small:x_large, y:yb]):
+ # if (h_layer[ x_small:x_large, y:yb] <= 0).all():
+ y_large = yb
+ else: break
+
+ # if (h_layer[ x_small:x_large, y_small:y_large] <= 0).all():
+ if check_valid_height_layer(h_layer[x_small:x_large, y_small:y_large]):
+
+ # new_ems = [[x_small, y_small, h], [x_large, y_large, container_h],[container_id]*3 ]
+ new_ems = [x_small, y_small, h, x_large, y_large, container_h]
+
+ if (x_large - x_small <= 0) or (y_large - y_small <= 0) :
+ new_ems = None
+
+ # NOTE remove small ems
+ if min_ems_width > 0:
+ if x_large - x_small < min_ems_width or y_large - y_small < min_ems_width:
+ new_ems = None
+
+ if new_ems is not None and new_ems not in empty_space_list:
+ empty_space_list.append(new_ems)
+
+def compute_ems(
+ heightmap: np.ndarray,
+ container_h: int,
+ min_ems_width: int = 0,
+ id_map: np.ndarray = None
+ ) -> list:
+ container_h = int(container_h)
+ empty_max_spaces = []
+
+ if id_map is not None:
+ m = id_map
+ else:
+ m = heightmap
+ corners, left_bottom_corners, x_borders, y_borders = compute_corners(m)
+
+ compute_empty_space(
+ container_h,
+ left_bottom_corners,
+ x_borders,
+ y_borders,
+ heightmap,
+ empty_max_spaces,
+ 'right',
+ 'right',
+ min_ems_width=min_ems_width
+ )
+
+ compute_empty_space(
+ container_h,
+ corners,
+ x_borders,
+ y_borders,
+ heightmap,
+ empty_max_spaces,
+ 'left-right',
+ 'left-right',
+ min_ems_width=min_ems_width
+ )
+
+ # NOTE stair corners
+ stair_corners = compute_stair_corners(heightmap, corners)
+ compute_empty_space(
+ container_h,
+ stair_corners,
+ x_borders,
+ y_borders,
+ heightmap,
+ empty_max_spaces,
+ 'right',
+ 'right',
+ min_ems_width=min_ems_width
+ )
+
+ return empty_max_spaces
+
+
+def add_box(heightmap, box, pos):
+ bx, by, bz = box
+ px, py, pz = pos
+
+ z = heightmap[px: px+bx, py:py+by].max()
+ heightmap[px: px+bx, py:py+by] = z + bz
+
+
+if __name__ == '__main__':
+ length = 10
+ h = np.zeros([length, length])
+
+ # add_box(h, [2,2,1], [0,0,0])
+ # add_box(h, [2,2,3], [2,3,0])
+ # add_box(h, [2,6,3], [7,3,0])
+ # add_box(h, [4,6,7], [0,3,0])
+ # add_box(h, [4,6,1], [3,0,0])
+ # add_box(h, [4,2,2], [5,2,0])
+ add_box(h, [9,9,9], [0,0,0])
+ print(h)
+ all_ems = compute_ems(h, length)
+ # all_ems = compute_ems(np.array(state), 30)
+ for ems in all_ems:
+ print(ems)
diff --git a/envs/Packing/env.py b/envs/Packing/env.py
new file mode 100755
index 0000000..5b5fbf5
--- /dev/null
+++ b/envs/Packing/env.py
@@ -0,0 +1,211 @@
+from typing import Optional
+
+from .container import Container
+import numpy as np
+import gymnasium as gym
+from gymnasium import spaces
+from .cutCreator import CuttingBoxCreator
+# from .mdCreator import MDlayerBoxCreator
+from .binCreator import RandomBoxCreator, LoadBoxCreator, BoxCreator
+
+from render import VTKRender
+
+
+class PackingEnv(gym.Env):
+ def __init__(
+ self,
+ container_size=(10, 10, 10),
+ item_set=None,
+ data_name=None,
+ load_test_data=False,
+ enable_rotation=False,
+ data_type="random",
+ reward_type=None,
+ action_scheme="heightmap",
+ k_placement=100,
+ is_render=False,
+ is_hold_on=False,
+ **kwags
+ ) -> None:
+ self.bin_size = container_size
+ self.area = int(self.bin_size[0] * self.bin_size[1])
+ # packing state
+ self.container = Container(*self.bin_size, rotation=enable_rotation)
+ self.can_rotate = enable_rotation
+ self.reward_type = reward_type
+ self.action_scheme = action_scheme
+ self.k_placement = k_placement
+ if action_scheme == "EMS":
+ self.candidates = np.zeros((self.k_placement, 6), dtype=np.int32) # (x1, y1, z1, x2, y2, H)
+ else:
+ self.candidates = np.zeros((self.k_placement, 3), dtype=np.int32) # (x, y, z)
+
+ # Generator for train/test data
+ if not load_test_data:
+ assert item_set is not None
+ if data_type == "random":
+ print(f"using items generated randomly")
+ self.box_creator = RandomBoxCreator(item_set)
+ if data_type == "cut":
+ print(f"using items generated through cutting method")
+ low = list(item_set[0])
+ up = list(item_set[-1])
+ low.extend(up)
+ self.box_creator = CuttingBoxCreator(container_size, low, self.can_rotate)
+ assert isinstance(self.box_creator, BoxCreator)
+ if load_test_data:
+ print(f"use box dataset: {data_name}")
+ self.box_creator = LoadBoxCreator(data_name)
+
+ self.test = load_test_data
+
+ # for rendering
+ if is_render:
+ self.renderer = VTKRender(container_size, auto_render=not is_hold_on)
+ self.render_box = None
+
+ self._set_space()
+
+ def _set_space(self) -> None:
+ obs_len = self.area + 3 # the state of bin + the dimension of box (l, w, h)
+ obs_len += self.k_placement * 6
+ self.action_space = spaces.Discrete(self.k_placement)
+ self.observation_space = spaces.Dict(
+ {
+ "obs": spaces.Box(low=0, high=max(self.bin_size), shape=(obs_len, )),
+ "mask": spaces.Discrete(self.k_placement)
+ }
+ )
+
+ def get_box_ratio(self):
+ coming_box = self.next_box
+ return (coming_box[0] * coming_box[1] * coming_box[2]) / (
+ self.container.dimension[0] * self.container.dimension[1] * self.container.dimension[2])
+
+ # box mask (W x L x 3)
+ def get_box_plain(self):
+ coming_box = self.next_box
+ x_plain = np.ones(self.container.dimension[:2], dtype=np.int32) * coming_box[0]
+ y_plain = np.ones(self.container.dimension[:2], dtype=np.int32) * coming_box[1]
+ z_plain = np.ones(self.container.dimension[:2], dtype=np.int32) * coming_box[2]
+ return x_plain, y_plain, z_plain
+
+ @property
+ def cur_observation(self):
+ """
+ get current observation and action mask
+ """
+ hmap = self.container.heightmap
+ size = list(self.next_box)
+ placements, mask = self.get_possible_position(size)
+ self.candidates = np.zeros_like(self.candidates)
+ if len(placements) != 0:
+ # print("candidates:")
+ # for c in placements:
+ # print(c)
+ self.candidates[0:len(placements)] = placements
+
+ size.extend([size[1], size[0], size[2]])
+ obs = np.concatenate((hmap.reshape(-1), np.array(size).reshape(-1), self.candidates.reshape(-1)))
+ mask = mask.reshape(-1)
+ return {
+ "obs": obs,
+ "mask": mask
+ }
+
+ @property
+ def next_box(self) -> list:
+ return self.box_creator.preview(1)[0]
+
+ def get_possible_position(self, next_box):
+ """
+ get possible actions for next box
+ Args:
+ scheme: the scheme how to generate candidates
+
+ Returns:
+ candidate action mask, i.e., the position where the current item should be placed
+ """
+ if self.action_scheme == "heightmap":
+ candidates = self.container.candidate_from_heightmap(next_box, self.k_placement)
+ elif self.action_scheme == "EP":
+ candidates, mask = self.container.candidate_from_EP(next_box, self.k_placement)
+ elif self.action_scheme == "EMS":
+ candidates, mask = self.container.candidate_from_EMS(next_box, self.k_placement)
+ elif self.action_scheme == "FC": # full coordinate space
+ candidates, mask = self.container.candidate_from_FC(next_box)
+ else:
+ raise NotImplementedError("action scheme not implemented")
+
+ return candidates, mask
+
+ def idx2pos(self, idx):
+ if idx >= self.k_placement - 1:
+ idx = idx - self.k_placement
+ rot = 1
+ else:
+ rot = 0
+
+ pos = self.candidates[idx][:3]
+
+ if rot == 1:
+ dim = [self.next_box[1], self.next_box[0], self.next_box[2]]
+ else:
+ dim = list(self.next_box)
+ self.render_box = [dim, pos]
+
+ return pos, rot, dim
+
+ def step(self, action):
+ """
+
+ :param action: action index
+ :return: cur_observation
+ reward
+ done, Whether to end boxing (i.e., the current box cannot fit in the bin)
+ info
+ """
+ # print(self.next_box)
+ pos, rot, size = self.idx2pos(action)
+
+ succeeded = self.container.place_box(self.next_box, pos, rot)
+
+ if not succeeded:
+ if self.reward_type == "terminal": # Terminal reward
+ reward = self.container.get_volume_ratio()
+ else: # Step-wise/Immediate reward
+ reward = 0.0
+ done = True
+
+ self.render_box = [[0, 0, 0], [0, 0, 0]]
+ info = {'counter': len(self.container.boxes), 'ratio': self.container.get_volume_ratio()}
+ return self.cur_observation, reward, done, False, info
+
+ box_ratio = self.get_box_ratio()
+
+ self.box_creator.drop_box() # remove current box from the list
+ self.box_creator.generate_box_size() # add a new box to the list
+
+ if self.reward_type == "terminal":
+ reward = 0.01
+ else:
+ reward = box_ratio
+ done = False
+ info = {'counter': len(self.container.boxes), 'ratio': self.container.get_volume_ratio()}
+
+ return self.cur_observation, reward, done, False, info
+
+ def reset(self, seed: Optional[int] = None):
+ super().reset(seed=seed)
+ self.box_creator.reset()
+ self.container = Container(*self.bin_size)
+ self.box_creator.generate_box_size()
+ self.candidates = np.zeros_like(self.candidates)
+ return self.cur_observation, {}
+
+ def seed(self, s=None):
+ np.random.seed(s)
+
+ def render(self):
+ self.renderer.add_item(self.render_box[0], self.render_box[1])
+ # self.renderer.save_img()
diff --git a/envs/Packing/mdCreator.py b/envs/Packing/mdCreator.py
new file mode 100755
index 0000000..e409c1d
--- /dev/null
+++ b/envs/Packing/mdCreator.py
@@ -0,0 +1,175 @@
+import numpy as np
+import copy
+import transforms3d.euler as eu
+import random
+import torch
+from .binCreator import BoxCreator
+
+
+class Box(object):
+ def __init__(self, given_bound, high, low, x, y, z, lx, ly, lz):
+ self.high_bound = high
+ self.low_bound = low
+ self.x = x
+ self.y = y
+ self.z = z
+ self.volume = x * y * z
+ self.location = np.array([lx, ly, lz])
+ self.extent = np.array([x / 2, y / 2, z / 2])
+ self.rotation = np.array([1, 0, 0, 0])
+ self.Rot = eu.quat2euler(self.rotation)
+ self.vertex = np.zeros((8, 3))
+ self.getCorners(self.extent * 2, self.location)
+ self.centre = (self.vertex[7] - self.vertex[0]) / 2 + self.location
+ self.x_flag = False
+ self.y_flag = False
+ self.given_bound = given_bound
+
+ def getCorners(self, size, location, quaternion=np.array([1, 0, 0, 0])):
+ for i in range(2):
+ for j in range(2):
+ for k in range(2):
+ self.vertex[i * 4 + j * 2 + k] = np.array(
+ [location[0] + k * size[0], location[1] + j * size[1], location[2] + i * size[2]])
+ R = eu.quat2mat(quaternion)
+ vertex = np.array(self.vertex, np.float32)
+ return np.dot(R, vertex.transpose())
+
+ def rotate_box(self, quaternion):
+ vertex = []
+ if isinstance(quaternion, np.ndarray):
+ if np.shape(quaternion) == (4, 1):
+ self.rotation = quaternion.transpose()
+ elif np.shape(quaternion) == (1, 4):
+ self.rotation = quaternion
+ elif isinstance(quaternion, list):
+ self.rotation = np.array(quaternion)
+ for i in range(2):
+ for j in range(2):
+ for k in range(2):
+ vertex.append([k * self.x, j * self.y, i * self.z])
+ # (0 0 0)(1 0 0)(0 1 0)(1 1 0)(0 0 1)(1 0 1)(0 1 1)(1 1 1)
+ self.R = eu.quat2mat(quaternion)
+ vertex = np.array(vertex, np.float32)
+ vertex = np.dot(self.R, vertex.transpose()).transpose()
+ for s in vertex:
+ s += self.location
+ self.vertex = vertex
+ self.centre = (self.vertex[7] - self.vertex[0]) / 2 + self.location
+
+ def benchmark_split(self):
+ flags = []
+ if self.x > self.given_bound[1]:
+ flags.append(0)
+ if self.y > self.given_bound[1]:
+ flags.append(1)
+ if self.z > self.given_bound[1]:
+ flags.append(2)
+
+ divide_flag = random.choice(flags)
+
+ if divide_flag == 0:
+ if self.x <= self.given_bound[0]:
+ return False
+ rand_x = random.randint(1, self.x)
+ if rand_x < self.given_bound[0] or self.x - rand_x < self.given_bound[0]:
+ return False
+ box1 = Box(self.given_bound, self.high_bound, self.low_bound, rand_x, self.y, self.z, self.location[0],
+ self.location[1],
+ self.location[2])
+ box2 = Box(self.given_bound, self.high_bound, self.low_bound, self.x - rand_x, self.y, self.z,
+ self.location[0] + rand_x,
+ self.location[1], self.location[2])
+ elif divide_flag == 1:
+ if self.y < self.given_bound[0]:
+ return False
+ rand_y = random.randint(1, self.y)
+ if rand_y < self.given_bound[0] or self.y - rand_y < self.given_bound[0]:
+ return False
+ box1 = Box(self.given_bound, self.high_bound, self.low_bound, self.x, rand_y, self.z, self.location[0],
+ self.location[1],
+ self.location[2])
+ box2 = Box(self.given_bound, self.high_bound, self.low_bound, self.x, self.y - rand_y, self.z,
+ self.location[0],
+ self.location[1] + rand_y, self.location[2])
+ else:
+ if self.z < self.given_bound[0]:
+ return False
+ rand_z = random.randint(1, self.z)
+ if rand_z < self.given_bound[0] or self.z - rand_z < self.given_bound[0]:
+ return False
+ box1 = Box(self.given_bound, self.high_bound - rand_z, self.low_bound, self.x, self.y, self.z - rand_z,
+ self.location[0],
+ self.location[1], self.low_bound)
+ box2 = Box(self.given_bound, self.high_bound, self.high_bound - rand_z, self.x, self.y, rand_z,
+ self.location[0],
+ self.location[1], self.high_bound - rand_z)
+ return box1, box2
+
+
+class bin(object):
+ def __init__(self, container_size, given_bound):
+ self.bin = Box(given_bound, container_size[2], 0, container_size[0], container_size[1], container_size[2], 0, 0,
+ 0)
+ self.boxes = [self.bin]
+ self.given_bound = given_bound
+ print(self.given_bound)
+
+ def is_valid(self, box):
+ if self.given_bound[1] >= box.size_x >= self.given_bound[0] and \
+ self.given_bound[1] >= box.size_y >= self.given_bound[0] \
+ and self.given_bound[1] >= box.size_z >= self.given_bound[0]:
+ return True
+ return False
+
+ def gen_benchmark(self):
+ vaild_box = []
+ invalid_box = copy.deepcopy(self.boxes)
+ while True:
+ for box in invalid_box:
+ divide_boxes = box.benchmark_split()
+ if not isinstance(divide_boxes, bool):
+ invalid_box.remove(box)
+ for sub_box in divide_boxes:
+ if self.is_valid(sub_box):
+ vaild_box.append(sub_box)
+ else:
+ invalid_box.append(sub_box)
+ if len(invalid_box) == 0:
+ self.boxes = vaild_box
+ if len(vaild_box) <= 10:
+ print("???")
+ break
+ return True
+
+ def depart_box(self):
+ self.boxes.sort(key=lambda x: x.low_bound, reverse=False)
+
+ def reset(self):
+ while True:
+ self.boxes = [self.bin]
+ if self.gen_benchmark():
+ break
+ self.depart_box()
+
+
+class MDlayerBoxCreator(BoxCreator):
+ def __init__(self, container_size, given_bound):
+ super().__init__()
+ self.container = bin(container_size, given_bound)
+ self.update_index = 0
+
+ def reset(self):
+ self.box_list.clear()
+ self.index = 0
+ self.container.reset()
+ self.boxes = self.container.boxes
+ self.default_box_set = []
+ for box in self.boxes:
+ self.default_box_set.append([box.x, box.y, box.z])
+ self.default_box_set.append([10, 10, 10])
+ self.box_set = self.default_box_set
+
+ def generate_box_size(self, **kwargs):
+ self.box_list.append(self.box_set[self.index])
+ self.index += 1
diff --git a/envs/Packing/utils.py b/envs/Packing/utils.py
new file mode 100755
index 0000000..78507f7
--- /dev/null
+++ b/envs/Packing/utils.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+
+import copy
+from collections import namedtuple
+import time
+
+import numpy as np
+
+Rectangle = namedtuple("Rectangle", "top bottom left right")
+
+
+class Projection:
+ """ Projection direction """
+ Xy = 0
+ Xz = 1
+ Yx = 2
+ Yz = 3
+ Zx = 4
+ Zy = 5
+
+
+def obs_to_list(observation, container_size):
+ if not isinstance(observation, np.ndarray):
+ obs_info = observation.cpu().numpy()
+ else:
+ obs_info = observation
+
+ obs_info = obs_info.reshape((4, -1))
+ new_box_l = int(obs_info[1][0])
+ new_box_w = int(obs_info[2][0])
+ new_box_h = int(obs_info[3][0])
+
+ plain = obs_info[0].reshape((container_size[0], container_size[1]))
+
+ return plain, (new_box_l, new_box_w, new_box_h)
+
+
+def find_rectangles(arr: list, height: int) -> list:
+ """
+ detect rectangles from 2d list (height map)
+ Args:
+ arr: 2d list
+ Returns:
+ rectangles: list of Rectangle
+ """
+ # Deeply copy the array so that it can be modified safely
+ arr = [row[:] for row in arr]
+
+ rectangles = []
+
+ for top, row in enumerate(arr):
+ start = 0
+
+ # Look for rectangles whose top row is here
+ while True:
+ try:
+ left = row.index(0, start)
+ except ValueError:
+ break
+
+ # Set start to one past the last 0 in the contiguous line of 0s
+ try:
+ start = row.index(1, left)
+ except ValueError:
+ start = len(row)
+
+ right = start - 1
+
+ # if ( # Width == 1
+ # left == right): # or
+ # # There are 0s above
+ # # top > 0 and not all(arr[top - 1][left:right + 1])):
+ # continue
+
+ bottom = top + 1
+ while (bottom < len(arr) and
+ # No extra zeroes on the sides
+ # (left == 0 or arr[bottom][left - 1]) and
+ # (right == len(row) - 1 or arr[bottom][right + 1]) and
+ # All zeroes in the row
+ not any(arr[bottom][left:right + 1])):
+ bottom += 1
+
+ # The loop ends when bottom has gone too far, so backtrack
+ bottom -= 1
+
+ # if ( # Height == 1
+ # bottom == top): # or
+ # # There are 0s beneath
+ # # (bottom < len(arr) - 1 and
+ # # not all(arr[bottom + 1][left:right + 1]))):
+ # continue
+
+ # rectangles.append(Rectangle(top, bottom, left, right))
+ rectangles.append([bottom - top + 1, right - left + 1, height, top, left, 0])
+
+ # Remove the rectangle so that it doesn't affect future searches
+ for i in range(top, bottom + 1):
+ arr[i][left:right + 1] = [1] * (right + 1 - left)
+
+ return rectangles
+
+
+def extract_items_from_heightmap(observation: np.ndarray) -> list:
+ """
+ Args:
+ observation: heightmap 2d array
+
+ Returns:
+ item_list: list of items extracted from current height map
+ """
+ # time_start = time.time()
+
+ height_arr = copy.deepcopy(observation)
+ height_value = np.unique(height_arr)
+ height_value_without_zero = height_value[np.nonzero(height_value)]
+
+ rectangles = [] # a rectangle represents a item
+ item_list = [] # extracted items
+
+ # distinguish objects by height
+ for height in height_value_without_zero:
+ state_height = height_arr - height
+ state_height[np.nonzero(state_height)] = 1
+ state_height_list = state_height.tolist()
+
+ rectangles = find_rectangles(state_height_list, height)
+ # rectangles = sweep(state_height_list)
+ # print("rectangles: ", rectangles)
+
+ # for rect in rectangles:
+ # length = rect.bottom - rect.top + 1
+ # width = rect.right - rect.left + 1
+
+ # item = [length, width, height, rect.top, rect.left, 0]
+ # item_list.append(item)
+ item_list.extend(rectangles) # (length, width, height, x, y, z)
+
+ # time_end = time.time()
+ # print("time cost: ", (time_end - time_start) * 1000, "ms")
+ return item_list
+
+
+def can_take_projection(new_item, placed_item, ep_dir: int, proj_dir: int) -> bool:
+
+ """
+ function returning true if an EP(generation direction: ep_dir) of item k can be projected(projection direction:
+ proj_dir) on the item i.
+ :param new_item:
+ :param placed_item:
+ :param ep_dir: (number 0,1,2 corresponding to x, y, z), extreme point generation direction
+ :param proj_dir: extreme point projection direction
+ :return: bool
+ """
+ new_dim = new_item[:3]
+ placed_dim = placed_item[:3]
+
+ remain_dir = 3 - ep_dir - proj_dir
+ epsilon = 0.0
+ proj_flag = True
+
+ if placed_item[-3:][proj_dir] + placed_dim[proj_dir] > new_item[-3:][proj_dir] - epsilon:
+ # i.e. piece is further from axis in projection direction
+ proj_flag = False
+ return proj_flag
+
+ if placed_item[-3:][ep_dir] > new_item[-3:][ep_dir] + new_dim[ep_dir] - epsilon:
+ # i.e. piece too far
+ proj_flag = False
+ return proj_flag
+
+ if placed_item[-3:][ep_dir] + placed_dim[ep_dir] < new_item[-3:][ep_dir] + new_dim[ep_dir] + epsilon:
+ # i.e. piece not far enough
+ proj_flag = False
+ return proj_flag
+
+ if placed_item[-3:][remain_dir] > new_item[-3:][remain_dir] - epsilon:
+ # i.e. piece too far
+ proj_flag = False
+ return proj_flag
+
+ if placed_item[-3:][remain_dir] + placed_dim[remain_dir] < new_item[-3:][remain_dir] + epsilon:
+ # i.e. piece not far enough
+ proj_flag = False
+ return proj_flag
+
+ return proj_flag
+
+
+
+if __name__ == "__main__":
+ # test find_rectangles
+ state = [[15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [15, 15, 15, 15, 15, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 18, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 5, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 5, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 5, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 5, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 5, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 0, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 0, 0, 0, 0, 0],
+ [24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 15, 15, 15, 15, 15, 0, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 15, 15, 15, 15, 15, 15, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 15, 15, 15, 15, 15, 15, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 15, 15, 15, 15, 15, 15, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 15, 15, 15, 15, 15, 15, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 15, 15, 15, 15, 15, 15, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 10, 10, 10, 10, 10, 0, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 10, 10, 10, 10, 10, 0, 0, 0, 0, 0],
+ [22, 22, 22, 22, 22, 12, 12, 12, 12, 12, 10, 10, 10, 10, 10, 0, 0, 0, 0, 0],
+ ]
+
+ rectangles = extract_items_from_heightmap(np.array(state))
+ print(rectangles)
diff --git a/envs/__init__.py b/envs/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/envs/__init__.py
@@ -0,0 +1 @@
+
diff --git a/images/demo.gif b/images/demo.gif
new file mode 100644
index 0000000..bfa56f6
Binary files /dev/null and b/images/demo.gif differ
diff --git a/images/overview.png b/images/overview.png
new file mode 100644
index 0000000..032b55c
Binary files /dev/null and b/images/overview.png differ
diff --git a/masked_a2c.py b/masked_a2c.py
new file mode 100644
index 0000000..26b2d07
--- /dev/null
+++ b/masked_a2c.py
@@ -0,0 +1,192 @@
+# https://github.com/davidluozhang/BasedRL
+# adapted from https://github.com/thu-ml/tianshou/issues/742
+
+from typing import Any, Dict, List, Optional, Type, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import A2CPolicy
+from tianshou.utils.net.common import ActorCritic
+
+
+class MaskedA2CPolicy(A2CPolicy):
+ """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param float vf_coef: weight for value loss. Default to 0.5.
+ :param float ent_coef: weight for entropy loss. Default to 0.01.
+ :param float max_grad_norm: clipping gradients in back propagation. Default to
+ None.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close to
+ 1. Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the
+ model; should be as large as possible within the memory constraint.
+ Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ vf_coef: float = 0.5,
+ ent_coef: float = 0.01,
+ max_grad_norm: Optional[float] = None,
+ gae_lambda: float = 0.95,
+ max_batchsize: int = 256,
+ **kwargs: Any
+ ) -> None:
+ super().__init__(
+ actor,
+ critic,
+ optim,
+ dist_fn,
+ vf_coef,
+ ent_coef,
+ max_grad_norm,
+ gae_lambda,
+ max_batchsize,
+ **kwargs
+ )
+
+ def _compute_returns(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ v_s, v_s_ = [], []
+ with torch.no_grad():
+ for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+ v_s.append(self.critic(minibatch.obs.obs))
+ v_s_.append(self.critic(minibatch.obs_next.obs))
+ batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
+ v_s = batch.v_s.cpu().numpy()
+ v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
+ # when normalizing values, we do not minus self.ret_rms.mean to be numerically
+ # consistent with OPENAI baselines' value normalization pipeline. Emperical
+ # study also shows that "minus mean" will harm performances a tiny little bit
+ # due to unknown reasons (on Mujoco envs, not confident, though).
+ if self._rew_norm: # unnormalize v_s & v_s_
+ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
+ v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
+ unnormalized_returns, advantages = self.compute_episodic_return(
+ batch,
+ buffer,
+ indices,
+ v_s_,
+ v_s,
+ gamma=self._gamma,
+ gae_lambda=self._lambda
+ )
+ if self._rew_norm:
+ batch.returns = unnormalized_returns / \
+ np.sqrt(self.ret_rms.var + self._eps)
+ self.ret_rms.update(unnormalized_returns)
+ else:
+ batch.returns = unnormalized_returns
+ batch.returns = to_torch_as(batch.returns, batch.v_s)
+ batch.adv = to_torch_as(advantages, batch.v_s)
+ return batch
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ losses, actor_losses, vf_losses, ent_losses = [], [], [], []
+ for _ in range(repeat):
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # calculate loss for actor
+ dist = self(minibatch).dist
+ log_prob = dist.log_prob(minibatch.act)
+ log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
+ actor_loss = -(log_prob * minibatch.adv).mean()
+ # calculate loss for critic
+ value = self.critic(minibatch.obs.obs).flatten()
+ vf_loss = F.mse_loss(minibatch.returns, value)
+ # calculate regularization and overall loss
+ ent_loss = dist.entropy().mean()
+ loss = actor_loss + self._weight_vf * vf_loss \
+ - self._weight_ent * ent_loss
+ self.optim.zero_grad()
+ loss.backward()
+ if self._grad_norm: # clip large gradient
+ nn.utils.clip_grad_norm_(
+ self._actor_critic.parameters(), max_norm=self._grad_norm
+ )
+ self.optim.step()
+ actor_losses.append(actor_loss.item())
+ vf_losses.append(vf_loss.item())
+ ent_losses.append(ent_loss.item())
+ losses.append(loss.item())
+
+ return {
+ "loss": losses,
+ "loss/actor": actor_losses,
+ "loss/vf": vf_losses,
+ "loss/ent": ent_losses,
+ }
+
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+ * ``act`` the action.
+ * ``logits`` the network's raw output.
+ * ``dist`` the action distribution.
+ * ``state`` the hidden state.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ logits, hidden = self.actor(batch.obs, state=state)
+ # mask = torch.as_tensor(batch.obs.mask, dtype=bool, device=logits.device)
+ mask = torch.FloatTensor(batch.obs.mask).to(logits.device)
+
+ if isinstance(logits, tuple):
+ dist = self.dist_fn(*logits)
+ else:
+ dist = self.dist_fn(logits=logits, masks=mask)
+ if self._deterministic_eval and not self.training:
+ if self.action_type == "discrete":
+ act = dist.logits.argmax(-1)
+ elif self.action_type == "continuous":
+ act = logits[0]
+ else:
+ act = dist.sample()
+ return Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/masked_ppo.py b/masked_ppo.py
new file mode 100644
index 0000000..74ff5df
--- /dev/null
+++ b/masked_ppo.py
@@ -0,0 +1,223 @@
+# https://github.com/davidluozhang/BasedRL
+# adapted from https://github.com/thu-ml/tianshou/issues/742
+
+from typing import Any, Dict, List, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import A2CPolicy, PPOPolicy
+from tianshou.utils.net.common import ActorCritic
+
+
+class MaskedPPOPolicy(PPOPolicy):
+ r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
+ paper. Default to 0.2.
+ :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+ where c > 1 is a constant indicating the lower bound.
+ Default to 5.0 (set None if you do not want to use it).
+ :param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
+ Default to True.
+ :param bool advantage_normalization: whether to do per mini-batch advantage
+ normalization. Default to True.
+ :param bool recompute_advantage: whether to recompute advantage every update
+ repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+ Default to False.
+ :param float vf_coef: weight for value loss. Default to 0.5.
+ :param float ent_coef: weight for entropy loss. Default to 0.01.
+ :param float max_grad_norm: clipping gradients in back propagation. Default to
+ None.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close
+ to 1, also normalize the advantage to Normal(0, 1). Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the model;
+ should be as large as possible within the memory constraint. Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ eps_clip: float = 0.2,
+ dual_clip: Optional[float] = None,
+ value_clip: bool = False,
+ advantage_normalization: bool = True,
+ recompute_advantage: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(actor, critic, optim, dist_fn, **kwargs)
+ self._eps_clip = eps_clip
+ assert dual_clip is None or dual_clip > 1.0, \
+ "Dual-clip PPO parameter should greater than 1.0."
+ self._dual_clip = dual_clip
+ self._value_clip = value_clip
+ if not self._rew_norm:
+ assert not self._value_clip, \
+ "value clip is available only when `reward_normalization` is True"
+ self._norm_adv = advantage_normalization
+ self._recompute_adv = recompute_advantage
+ self._actor_critic: ActorCritic
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ losses, clip_losses, vf_losses, ent_losses = [], [], [], []
+ for step in range(repeat):
+ if self._recompute_adv and step > 0:
+ batch = self._compute_returns(batch, self._buffer, self._indices)
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # calculate loss for actor
+ dist = self(minibatch).dist
+ if self._norm_adv:
+ mean, std = minibatch.adv.mean(), minibatch.adv.std()
+ minibatch.adv = (minibatch.adv -
+ mean) / (std + self._eps) # per-batch norm
+ ratio = (dist.log_prob(minibatch.act) -
+ minibatch.logp_old).exp().float()
+ ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+ surr1 = ratio * minibatch.adv
+ surr2 = ratio.clamp(
+ 1.0 - self._eps_clip, 1.0 + self._eps_clip
+ ) * minibatch.adv
+ if self._dual_clip:
+ clip1 = torch.min(surr1, surr2)
+ clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+ clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
+ else:
+ clip_loss = -torch.min(surr1, surr2).mean()
+ # calculate loss for critic
+ value = self.critic(minibatch.obs).flatten()
+ if self._value_clip:
+ v_clip = minibatch.v_s + \
+ (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+ vf1 = (minibatch.returns - value).pow(2)
+ vf2 = (minibatch.returns - v_clip).pow(2)
+ vf_loss = torch.max(vf1, vf2).mean()
+ else:
+ vf_loss = (minibatch.returns - value).pow(2).mean()
+ # calculate regularization and overall loss
+ ent_loss = dist.entropy().mean()
+ loss = clip_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss
+ self.optim.zero_grad()
+ loss.backward()
+ if self._grad_norm: # clip large gradient
+ nn.utils.clip_grad_norm_(
+ self._actor_critic.parameters(), max_norm=self._grad_norm,
+ )
+ self.optim.step()
+ clip_losses.append(clip_loss.item())
+ vf_losses.append(vf_loss.item())
+ ent_losses.append(ent_loss.item())
+ losses.append(loss.item())
+
+ return {
+ "loss": losses,
+ "loss/clip": clip_losses,
+ "loss/vf": vf_losses,
+ "loss/ent": ent_losses,
+ }
+
+ def _compute_returns(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ v_s, v_s_ = [], []
+ with torch.no_grad():
+ for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+ v_s.append(self.critic(minibatch.obs))
+ v_s_.append(self.critic(minibatch.obs_next))
+ batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
+ v_s = batch.v_s.cpu().numpy()
+ v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
+ # when normalizing values, we do not minus self.ret_rms.mean to be numerically
+ # consistent with OPENAI baselines' value normalization pipeline. Emperical
+ # study also shows that "minus mean" will harm performances a tiny little bit
+ # due to unknown reasons (on Mujoco envs, not confident, though).
+ if self._rew_norm: # unnormalize v_s & v_s_
+ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
+ v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
+ unnormalized_returns, advantages = self.compute_episodic_return(
+ batch,
+ buffer,
+ indices,
+ v_s_,
+ v_s,
+ gamma=self._gamma,
+ gae_lambda=self._lambda
+ )
+ if self._rew_norm:
+ batch.returns = unnormalized_returns / \
+ np.sqrt(self.ret_rms.var + self._eps)
+ self.ret_rms.update(unnormalized_returns)
+ else:
+ batch.returns = unnormalized_returns
+ batch.returns = to_torch_as(batch.returns, batch.v_s)
+ batch.adv = to_torch_as(advantages, batch.v_s)
+ return batch
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+ * ``act`` the action.
+ * ``logits`` the network's raw output.
+ * ``dist`` the action distribution.
+ * ``state`` the hidden state.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ logits, hidden = self.actor(batch.obs, state=state)
+ # mask = torch.as_tensor(batch.obs.mask, dtype=bool, device=logits.device)
+ mask = torch.FloatTensor(batch.obs.mask).to(logits.device)
+
+ if isinstance(logits, tuple):
+ dist = self.dist_fn(*logits)
+ else:
+ dist = self.dist_fn(logits=logits, masks=mask)
+ if self._deterministic_eval and not self.training:
+ if self.action_type == "discrete":
+ act = dist.logits.argmax(-1)
+ elif self.action_type == "continuous":
+ act = logits[0]
+ else:
+ act = dist.sample()
+ return Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..a6630fb
--- /dev/null
+++ b/model.py
@@ -0,0 +1,327 @@
+
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class Attention(nn.Module):
+ def __init__(self, embed_size, heads):
+ super(Attention, self).__init__()
+ self.embed_size = embed_size
+ self.heads = heads
+ self.head_dim = embed_size // heads
+
+ assert (
+ self.head_dim * heads == embed_size
+ ), "Embedding size needs to be divisible by heads"
+
+ self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
+ self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
+ self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
+ self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
+
+ def forward(self, query, keys, values, pad_mask=None):
+ # A.P.: Get number of training examples
+ N = query.shape[0]
+
+ value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
+
+ #A.P.: Split the embedding into self.heads different pieces
+ values = values.reshape(N, value_len, self.heads, self.head_dim)
+ keys = keys.reshape(N, key_len, self.heads, self.head_dim)
+ query = query.reshape(N, query_len, self.heads, self.head_dim)
+
+ values = self.values(values) # A.P.: (N, value_len, heads, head_dim)
+ keys = self.keys(keys) # A.P.: (N, key_len, heads, head_dim)
+ queries = self.queries(query) # A.P.: (N, query_len, heads, heads_dim)
+
+ # A.P.: Einsum does matrix mult. for query*keys for each training example
+ # with every other training example, don't be confused by einsum
+ # it's just how I like doing matrix multiplication & bmm
+
+ energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
+ # A.P.: queries shape: (N, query_len, heads, heads_dim),
+ # A.P.: keys shape: (N, key_len, heads, heads_dim)
+ # A.P.: energy: (N, heads, query_len, key_len)
+
+ # Mask padded indices so their weights become 0
+ if pad_mask is not None:
+ pad_mask = pad_mask.unsqueeze(-1).expand(N, query_len, key_len)
+ pad_mask = pad_mask.unsqueeze(1).repeat(1, self.heads, 1, 1)
+ energy = energy.masked_fill(pad_mask==0, -1e18)
+ # energy = energy.masked_fill(pad_mask==0, float("-inf"))
+
+ # A.P.: Normalize energy values similarly to seq2seq + attention
+ # so that they sum to 1. Also divide by scaling factor for
+ # better stability
+ attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
+ # A.P.: attention shape: (N, heads, query_len, key_len)
+
+ out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
+ N, query_len, self.heads * self.head_dim
+ )
+ # A.P.: attention shape: (N, heads, query_len, key_len)
+ # A.P.: values shape: (N, value_len, heads, heads_dim)
+ # A.P.: out after matrix multiply: (N, query_len, heads, head_dim), then
+ # we reshape and flatten the last two dimensions.
+
+ out = self.fc_out(out)
+ # A.P.: Linear layer doesn't modify the shape, final shape will be (N, query_len, embed_size)
+
+ return out
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, embed_size, heads, dropout, forward_expansion):
+ super(TransformerBlock, self).__init__()
+ self.attention = Attention(embed_size, heads)
+ self.norm1 = nn.LayerNorm(embed_size)
+ self.norm2 = nn.LayerNorm(embed_size)
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(embed_size, forward_expansion * embed_size),
+ nn.ReLU(),
+ nn.Linear(forward_expansion * embed_size, embed_size),
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, query, key, value, pad_mask=None):
+ attention = self.attention(query, key, value, pad_mask)
+
+ # A.P.: Add skip connection, run through normalization and finally dropout
+ x = self.dropout(self.norm1(attention + query))
+ forward = self.feed_forward(x)
+ out = self.dropout(self.norm2(forward + x))
+ return out
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, embed_size, heads, forward_expansion, dropout):
+ super(EncoderBlock, self).__init__()
+
+ self.item_embedding = TransformerBlock(embed_size, heads, dropout, forward_expansion)
+ self.ems_embedding = TransformerBlock(embed_size, heads, dropout, forward_expansion)
+ self.ems_on_item = TransformerBlock(embed_size, heads, dropout, forward_expansion)
+ self.item_on_ems = TransformerBlock(embed_size, heads, dropout, forward_expansion)
+
+ def forward(self, item_feature, ems_feature, mask=None):
+ # self-attention
+ item_embedding = self.item_embedding(item_feature, item_feature, item_feature)
+ ems_embedding = self.ems_embedding(ems_feature, ems_feature, ems_feature, mask)
+ # cross-attention
+ ems_on_item = self.ems_on_item(ems_embedding, item_embedding, item_embedding, mask)
+ item_on_ems = self.item_on_ems(item_embedding, ems_embedding, ems_embedding)
+
+ return item_on_ems, ems_on_item
+
+
+class ActorHead(nn.Module):
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ embed_size: int,
+ padding_mask: bool = False,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__()
+ self.padding_mask = padding_mask
+ self.device = device
+ self.preprocess = preprocess_net
+ self.layer_1 = nn.Sequential(
+ init_(nn.Linear(embed_size, embed_size)),
+ nn.LeakyReLU(),
+ )
+ self.layer_2 = nn.Sequential(
+ init_(nn.Linear(embed_size, embed_size)),
+ nn.LeakyReLU(),
+ )
+
+ def forward(
+ self,
+ obs: Dict,
+ state: Any = None,
+ info: Dict[str, Any] = {}
+ ) -> Tuple[torch.Tensor, Any]:
+ batch_size = obs.obs.shape[0]
+
+ if self.padding_mask:
+ mask = torch.as_tensor(obs.mask, dtype=torch.bool, device=self.device)
+ mask = torch.sum(mask.reshape(batch_size, -1, 2), dim=-1).bool()
+ else:
+ mask = None
+
+ item_embedding, ems_embedding, hidden = self.preprocess(obs.obs, state, mask)
+ item_embedding = self.layer_1(item_embedding)
+ ems_embedding = self.layer_2(ems_embedding).permute(0, 2, 1)
+
+ logits = torch.bmm(item_embedding, ems_embedding).reshape(batch_size, -1)
+
+ return logits, hidden
+
+
+class CriticHead(nn.Module):
+ def __init__(
+ self,
+ k_placement: int,
+ preprocess_net: nn.Module,
+ embed_size: int,
+ padding_mask: bool = False,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__()
+ self.padding_mask = padding_mask
+ self.device = device
+ self.preprocess = preprocess_net
+ self.k_placement = k_placement
+ self.layer_1 = nn.Sequential(
+ init_(nn.Linear(embed_size, embed_size)),
+ nn.LeakyReLU(),
+ )
+ self.layer_2 = nn.Sequential(
+ init_(nn.Linear(embed_size, embed_size)),
+ nn.LeakyReLU(),
+ )
+ self.layer_3 = nn.Sequential(
+ init_(nn.Linear(2 * embed_size, embed_size)),
+ nn.LeakyReLU(),
+ init_(nn.Linear(embed_size, embed_size)),
+ nn.LeakyReLU(),
+ init_(nn.Linear(embed_size, 1))
+ )
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ **kwargs: Any
+ ) -> torch.Tensor:
+ batch_size = obs.shape[0]
+ mask = torch.as_tensor(obs.mask, dtype=torch.bool, device=self.device)
+ mask = torch.sum(mask.reshape(batch_size, -1, 2), dim=-1).bool()
+ if self.padding_mask:
+ item_embedding, ems_embedding, _ = self.preprocess(obs.obs, mask)
+ else:
+ item_embedding, ems_embedding, _ = self.preprocess(obs.obs)
+
+ item_embedding = self.layer_1(item_embedding)
+ ems_embedding = self.layer_2(ems_embedding)
+
+ item_embedding = torch.sum(item_embedding, dim=-2)
+ ems_embedding = torch.sum(ems_embedding * mask[..., None], dim=-2)
+
+ joint_embedding = torch.cat((item_embedding, ems_embedding), dim=-1)
+
+ state_value = self.layer_3(joint_embedding)
+ return state_value
+
+
+class ShareNet(nn.Module):
+ def __init__(
+ self,
+ k_placement: int = 100,
+ box_max_size: int = 5,
+ container_size: Sequence[int] = [10, 10, 10],
+ embed_size: int = 32,
+ num_layers: int = 6,
+ forward_expansion: int = 4,
+ heads: int = 6,
+ dropout: float = 0,
+ device: Union[str, int, torch.device] = "cpu",
+ place_gen: str = "EMS",
+ ) -> None:
+ super().__init__()
+
+ self.device = device
+ self.k_placement = k_placement
+ self.container_size = container_size
+ self.place_gen = place_gen
+ if place_gen == "EMS":
+ input_size = 6
+ else:
+ input_size = 3
+
+ self.factor = 1 / max(container_size)
+
+ self.item_encoder = nn.Sequential(
+ init_(nn.Linear(3, 32)),
+ nn.LeakyReLU(),
+ init_(nn.Linear(32, embed_size)),
+ )
+
+ self.placement_encoder = nn.Sequential(
+ init_(nn.Linear(input_size, 32)),
+ nn.LeakyReLU(),
+ init_(nn.Linear(32, embed_size)),
+ )
+
+ self.backbone = nn.ModuleList(
+ [
+ EncoderBlock(
+ embed_size=embed_size,
+ heads=heads,
+ dropout=dropout,
+ forward_expansion=forward_expansion,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ mask: Union[np.ndarray, torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, Any]:
+ if not isinstance(obs, torch.Tensor):
+ obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device) * self.factor
+ if not isinstance(mask, torch.Tensor) and mask is not None:
+ mask = torch.as_tensor(mask, dtype=torch.float32, device=self.device) # (batch_size, k_placement)
+
+ obs_hm, obs_next, obs_placements = obs2input(obs, self.container_size, self.place_gen)
+
+ item_embedding = self.item_encoder(obs_next) # (batch_size, 2, emded_size)
+ placement_embedding = self.placement_encoder(obs_placements) # (batch_size, k_placement, emded_size)
+
+ for layer in self.backbone:
+ item_embedding, placement_embedding = layer(item_embedding, placement_embedding, mask)
+
+ return item_embedding, placement_embedding, state
+
+
+def obs2input(
+ obs: torch.Tensor,
+ container_size: Sequence[int],
+ place_gen: str = "EMS",
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ convert obsversation to input of the network
+
+ Returns:
+ hm: (batch, 1, L, W)
+ next_size: (batch, 2, 3)
+ placements: (batch, k_placement, 6)
+ """
+ batch_size = obs.shape[0]
+ hm = obs[:, :container_size[0]*container_size[1]].reshape((batch_size, 1, container_size[0], container_size[1]))
+ next_size = obs[:, container_size[0]*container_size[1]:container_size[0]*container_size[1] + 6]
+ # [[l, w, h], [w, l, h]]
+ next_size = next_size.reshape((batch_size, 2, 3))
+
+ if place_gen == "EMS":
+ # (x_1, y_1, z_1, x_2, y_2, H)
+ placements = obs[:, container_size[0]*container_size[1] + 6:].reshape((batch_size, -1, 6))
+ else:
+ placements = obs[:, container_size[0]*container_size[1] + 6:].reshape((batch_size, -1, 3))
+
+ return hm, next_size, placements
+
+
+def init(module, weight_init, bias_init, gain=1):
+ weight_init(module.weight.data, gain=gain)
+ bias_init(module.bias.data)
+ return module
+
+init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), nn.init.calculate_gain('leaky_relu'))
+
diff --git a/mycollector.py b/mycollector.py
new file mode 100644
index 0000000..277733e
--- /dev/null
+++ b/mycollector.py
@@ -0,0 +1,309 @@
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import gymnasium as gym
+import numpy as np
+import torch
+
+from tianshou.data import (
+ Batch,
+ CachedReplayBuffer,
+ ReplayBuffer,
+ ReplayBufferManager,
+ VectorReplayBuffer,
+ to_numpy,
+ Collector
+)
+from tianshou.data.batch import _alloc_by_keys_diff
+from tianshou.env import BaseVectorEnv, DummyVectorEnv
+from tianshou.policy import BasePolicy
+
+
+class PackCollector(Collector):
+ """Collector enables the policy to interact with different types of envs with \
+ exact number of steps or episodes.
+
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param env: a ``gym.Env`` environment or an instance of the
+ :class:`~tianshou.env.BaseVectorEnv` class.
+ :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+ If set to None, it will not store the data. Default to None.
+ :param function preprocess_fn: a function called before the data has been added to
+ the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
+ :param bool exploration_noise: determine whether the action needs to be modified
+ with corresponding policy's exploration noise. If so, "policy.
+ exploration_noise(act, batch)" will be called automatically to add the
+ exploration noise into action. Default to False.
+
+ The "preprocess_fn" is a function called before the data has been added to the
+ buffer with batch format. It will receive only "obs" and "env_id" when the
+ collector resets the environment, and will receive the keys "obs_next", "rew",
+ "terminated", "truncated, "info", "policy" and "env_id" in a normal env step.
+ Alternatively, it may also accept the keys "obs_next", "rew", "done", "info",
+ "policy" and "env_id".
+ It returns either a dict or a :class:`~tianshou.data.Batch` with the modified
+ keys and values. Examples are in "test/base/test_collector.py".
+
+ .. note::
+
+ Please make sure the given environment has a time limitation if using n_episode
+ collect option.
+
+ .. note::
+
+ In past versions of Tianshou, the replay buffer that was passed to `__init__`
+ was automatically reset. This is not done in the current implementation.
+ """
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ env: Union[gym.Env, BaseVectorEnv],
+ buffer: Optional[ReplayBuffer] = None,
+ preprocess_fn: Optional[Callable[..., Batch]] = None,
+ exploration_noise: bool = False,
+ ) -> None:
+ super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
+
+ def collect(
+ self,
+ n_step: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ random: bool = False,
+ render: Optional[float] = None,
+ no_grad: bool = True,
+ gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ """Collect a specified number of step or episode.
+
+ To ensure unbiased sampling result with n_episode option, this function will
+ first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
+ episodes, they will be collected evenly from each env.
+
+ :param int n_step: how many steps you want to collect.
+ :param int n_episode: how many episodes you want to collect.
+ :param bool random: whether to use random policy for collecting data. Default
+ to False.
+ :param float render: the sleep time between rendering consecutive frames.
+ Default to None (no rendering).
+ :param bool no_grad: whether to retain gradient in policy.forward(). Default to
+ True (no gradient retaining).
+ :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+ reset function. Defaults to None (extra keyword arguments)
+
+ .. note::
+
+ One and only one collection number specification is permitted, either
+ ``n_step`` or ``n_episode``.
+
+ :return: A dict including the following keys
+
+ * ``n/ep`` collected number of episodes.
+ * ``n/st`` collected number of steps.
+ * ``rews`` array of episode reward over collected episodes.
+ * ``lens`` array of episode length over collected episodes.
+ * ``idxs`` array of episode start index in buffer over collected episodes.
+ * ``rew`` mean of episodic rewards.
+ * ``len`` mean of episodic lengths.
+ * ``rew_std`` standard error of episodic rewards.
+ * ``len_std`` standard error of episodic lengths.
+ """
+ assert not self.env.is_async, "Please use AsyncCollector if using async venv."
+ if n_step is not None:
+ assert n_episode is None, (
+ f"Only one of n_step or n_episode is allowed in Collector."
+ f"collect, got n_step={n_step}, n_episode={n_episode}."
+ )
+ assert n_step > 0
+ if not n_step % self.env_num == 0:
+ warnings.warn(
+ f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
+ "which may cause extra transitions collected into the buffer."
+ )
+ ready_env_ids = np.arange(self.env_num)
+ elif n_episode is not None:
+ assert n_episode > 0
+ ready_env_ids = np.arange(min(self.env_num, n_episode))
+ self.data = self.data[:min(self.env_num, n_episode)]
+ else:
+ raise TypeError(
+ "Please specify at least one (either n_step or n_episode) "
+ "in AsyncCollector.collect()."
+ )
+
+ start_time = time.time()
+
+ step_count = 0
+ episode_count = 0
+ episode_rews = []
+ episode_lens = []
+ episode_start_indices = []
+ episode_ratios = []
+ episode_nums = []
+
+ while True:
+ assert len(self.data) == len(ready_env_ids)
+ # restore the state: if the last state is None, it won't store
+ last_state = self.data.policy.pop("hidden_state", None)
+
+ # get the next action
+ if random:
+ try:
+ act_sample = [
+ self._action_space[i].sample() for i in ready_env_ids
+ ]
+ except TypeError: # envpool's action space is not for per-env
+ act_sample = [self._action_space.sample() for _ in ready_env_ids]
+ act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
+ self.data.update(act=act_sample)
+ else:
+ if no_grad:
+ with torch.no_grad(): # faster than retain_grad version
+ # self.data.obs will be used by agent to get result
+ result = self.policy(self.data, last_state)
+ else:
+ result = self.policy(self.data, last_state)
+ # update state / act / policy into self.data
+ policy = result.get("policy", Batch())
+ assert isinstance(policy, Batch)
+ state = result.get("state", None)
+ if state is not None:
+ policy.hidden_state = state # save state into buffer
+ act = to_numpy(result.act)
+ if self.exploration_noise:
+ act = self.policy.exploration_noise(act, self.data)
+ self.data.update(policy=policy, act=act)
+
+ # get bounded and remapped actions first (not saved into buffer)
+ action_remap = self.policy.map_action(self.data.act)
+ # step in env
+ obs_next, rew, terminated, truncated, info = self.env.step(
+ action_remap, # type: ignore
+ ready_env_ids
+ )
+ done = np.logical_or(terminated, truncated)
+
+ self.data.update(
+ obs_next=obs_next,
+ rew=rew,
+ terminated=terminated,
+ truncated=truncated,
+ done=done,
+ info=info
+ )
+ if self.preprocess_fn:
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ done=self.data.done,
+ info=self.data.info,
+ policy=self.data.policy,
+ env_id=ready_env_ids,
+ act=self.data.act,
+ )
+ )
+
+ if render:
+ self.env.render()
+ if render > 0 and not np.isclose(render, 0):
+ time.sleep(render)
+
+ # add data into the buffer
+ ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
+ self.data, buffer_ids=ready_env_ids
+ )
+
+ # collect statistics
+ step_count += len(ready_env_ids)
+
+ if np.any(done):
+ env_ind_local = np.where(done)[0]
+ env_ind_global = ready_env_ids[env_ind_local]
+ episode_count += len(env_ind_local)
+ episode_lens.append(ep_len[env_ind_local])
+ episode_rews.append(ep_rew[env_ind_local])
+ episode_start_indices.append(ep_idx[env_ind_local])
+
+ episode_ratios.append(self.data.info['ratio'][env_ind_local])
+ episode_nums.append(self.data.info['counter'][env_ind_local])
+ # now we copy obs_next to obs, but since there might be
+ # finished episodes, we have to reset finished envs first.
+ self._reset_env_with_ids(
+ env_ind_local, env_ind_global, gym_reset_kwargs
+ )
+ for i in env_ind_local:
+ self._reset_state(i)
+
+ # remove surplus env id from ready_env_ids
+ # to avoid bias in selecting environments
+ if n_episode:
+ surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
+ if surplus_env_num > 0:
+ mask = np.ones_like(ready_env_ids, dtype=bool)
+ mask[env_ind_local[:surplus_env_num]] = False
+ ready_env_ids = ready_env_ids[mask]
+ self.data = self.data[mask]
+
+ self.data.obs = self.data.obs_next
+
+ if (n_step and step_count >= n_step) or \
+ (n_episode and episode_count >= n_episode):
+ break
+
+ # generate statistics
+ self.collect_step += step_count
+ self.collect_episode += episode_count
+ self.collect_time += max(time.time() - start_time, 1e-9)
+
+ if n_episode:
+ self.data = Batch(
+ obs={},
+ act={},
+ rew={},
+ terminated={},
+ truncated={},
+ done={},
+ obs_next={},
+ info={},
+ policy={}
+ )
+ self.reset_env()
+
+ if episode_count > 0:
+ rews, lens, idxs, ratios, nums = list(
+ map(
+ np.concatenate,
+ [episode_rews, episode_lens, episode_start_indices, episode_ratios, episode_nums]
+ )
+ )
+ rew_mean, rew_std = rews.mean(), rews.std()
+ len_mean, len_std = lens.mean(), lens.std()
+ ratio_mean, ratio_std = ratios.mean(), ratios.std()
+ num_mean, num_std = nums.mean(), nums.std()
+ else:
+ rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+ rew_mean = rew_std = len_mean = len_std = 0
+ ratios = np.array([])
+ nums = np.array([])
+ ratio_mean = ratio_std = 0
+ num_mean = num_std = 0
+
+ return {
+ "n/ep": episode_count,
+ "n/st": step_count,
+ "rews": rews,
+ "lens": lens,
+ "idxs": idxs,
+ "rew": rew_mean,
+ "len": len_mean,
+ "rew_std": rew_std,
+ "len_std": len_std,
+ "ratios": ratios,
+ "ratio": ratio_mean,
+ "ratio_std": ratio_std,
+ "nums": nums,
+ "num": num_mean,
+ "num_std": num_std,
+ }
diff --git a/render.py b/render.py
new file mode 100644
index 0000000..60f644e
--- /dev/null
+++ b/render.py
@@ -0,0 +1,224 @@
+from typing import Any, Dict, List, Optional, Type, Union
+import os
+import time
+import datetime
+
+import vtk
+# import vtkmodules.all as vtk
+
+
+vtk_color = {
+ 'Whites': ['antique_white', 'azure', 'bisque', 'blanched_almond',
+ 'cornsilk', 'eggshell', 'floral_white', 'gainsboro',
+ 'ghost_white', 'honeydew', 'ivory', 'lavender',
+ 'lavender_blush', 'lemon_chiffon', 'linen', 'mint_cream',
+ 'misty_rose', 'moccasin', 'navajo_white', 'old_lace',
+ 'papaya_whip', 'peach_puff', 'seashell', 'snow',
+ 'thistle', 'titanium_white', 'wheat', 'white',
+ 'white_smoke', 'zinc_white'],
+ 'Greys': ['cold_grey', 'dim_grey', 'grey', 'light_grey',
+ 'slate_grey', 'slate_grey_dark', 'slate_grey_light',
+ 'warm_grey'],
+ 'Reds': ['coral', 'coral_light',
+ 'hot_pink', 'light_salmon',
+ 'pink', 'pink_light',
+ 'raspberry', 'rose_madder', 'salmon',
+ ],
+ # 'Browns': ['beige', 'brown', 'brown_madder', 'brown_ochre',
+ # 'burlywood', 'burnt_sienna', 'burnt_umber', 'chocolate',
+ # 'flesh', 'flesh_ochre', 'gold_ochre',
+ # 'greenish_umber', 'khaki', 'khaki_dark', 'light_beige',
+ # 'peru', 'rosy_brown', 'raw_sienna', 'raw_umber', 'sepia',
+ # 'sienna', 'saddle_brown', 'sandy_brown', 'tan',
+ # 'van_dyke_brown'],
+ 'Oranges': ['cadmium_orange', 'cadmium_red_light', 'carrot',
+ 'dark_orange', 'mars_orange', 'mars_yellow', 'orange',
+ 'orange_red', 'yellow_ochre'],
+ 'Yellows': ['aureoline_yellow', 'banana', 'cadmium_lemon',
+ 'cadmium_yellow', 'cadmium_yellow_light', 'gold',
+ 'goldenrod', 'goldenrod_dark', 'goldenrod_light',
+ 'goldenrod_pale', 'light_goldenrod', 'melon',
+ 'yellow', 'yellow_light'],
+ 'Greens': ['chartreuse', 'chrome_oxide_green', 'cinnabar_green',
+ 'cobalt_green', 'emerald_green', 'forest_green',
+ 'green_dark', 'green_pale', 'green_yellow', 'lawn_green',
+ 'lime_green', 'mint', 'olive', 'olive_drab',
+ 'olive_green_dark', 'permanent_green', 'sap_green',
+ 'sea_green', 'sea_green_dark', 'sea_green_medium',
+ 'sea_green_light', 'spring_green', 'spring_green_medium',
+ 'terre_verte', 'viridian_light', 'yellow_green'],
+ 'Cyans': ['aquamarine', 'aquamarine_medium', 'cyan', 'cyan_white',
+ 'turquoise', 'turquoise_dark', 'turquoise_medium',
+ 'turquoise_pale'],
+ 'Blues': ['alice_blue', 'blue_light', 'blue_medium',
+ 'cadet', 'cobalt', 'cornflower', 'cerulean', 'dodger_blue',
+ 'indigo', 'manganese_blue', 'midnight_blue', 'navy',
+ 'peacock', 'powder_blue', 'royal_blue', 'slate_blue',
+ 'slate_blue_dark', 'slate_blue_light',
+ 'slate_blue_medium', 'sky_blue',
+ 'sky_blue_light', 'steel_blue', 'steel_blue_light',
+ 'turquoise_blue', 'ultramarine'],
+ 'Magentas': ['blue_violet', 'magenta',
+ 'orchid', 'orchid_dark', 'orchid_medium',
+ 'plum', 'purple',
+ 'purple_medium', 'ultramarine_violet', 'violet',
+ 'violet_dark', 'violet_red_medium',
+ 'violet_red_pale']
+}
+color_key = list(vtk_color.keys())
+
+
+class VTKRender:
+ def __init__(
+ self,
+ container_size: List[int],
+ win_size: List[int]=[600, 600],
+ offscreen: bool=True,
+ auto_render: bool=True
+ ) -> None:
+ self.container_size = container_size
+ self.item_idx = 0
+ self.auto_render = auto_render
+
+ # 1. render
+ self.render = vtk.vtkRenderer()
+ self.render.SetBackground(1.0, 1.0, 1.0)
+
+ # 2. render window
+ self.render_window = vtk.vtkRenderWindow()
+ # if offscreen:
+ # self.render_window.SetOffScreenRendering(1)
+ self.render_window.SetWindowName("Packing Visualization")
+ self.render_window.SetSize(win_size[0], win_size[1])
+ self.render_window.AddRenderer(self.render)
+
+ # 3. interactor
+ self.interactor = vtk.vtkRenderWindowInteractor()
+ self.interactor.SetRenderWindow(self.render_window)
+
+ # 4. camera
+ self.camera = vtk.vtkCamera()
+ self.camera.SetPosition(
+ 2.5 * max(self.container_size),
+ -2 * max(self.container_size),
+ 2 * max(self.container_size)
+ )
+ self.camera.SetViewUp(0, 0, 1.5 * max(container_size))
+ self.render.SetActiveCamera(self.camera)
+
+ # 5. axes
+ self._init_axes()
+
+ # 6. container (cube)
+ self._init_container()
+
+ self.interactor.Initialize()
+ self.render_window.Render()
+ time.sleep(0.5)
+
+ def _init_axes(self) -> None:
+ axes = vtk.vtkAxesActor()
+
+ transform = vtk.vtkTransform()
+ transform.Translate(
+ -0.5 * self.container_size[0],
+ -0.5 * self.container_size[1],
+ -0.5 * self.container_size[2]
+ )
+
+ axes.SetUserTransform(transform)
+
+ sigma = 0.1
+ axes_l_x = self.container_size[0] + sigma * self.container_size[2]
+ axes_l_y = self.container_size[1] + sigma * self.container_size[2]
+ axes_l_z = (1 + sigma) * self.container_size[2]
+
+ axes.SetTotalLength(axes_l_x, axes_l_y, axes_l_z)
+ axes.SetNormalizedShaftLength(1, 1, 1)
+ axes.SetNormalizedTipLength(0.05, 0.05, 0.05)
+ axes.AxisLabelsOff()
+
+ self.render.AddActor(axes)
+
+ def _init_container(self) -> None:
+ container = vtk.vtkCubeSource()
+ container.SetXLength(self.container_size[0])
+ container.SetYLength(self.container_size[1])
+ container.SetZLength(self.container_size[2])
+ container.SetCenter([0, 0, 0])
+
+ mapper = vtk.vtkPolyDataMapper()
+ mapper.SetInputConnection(container.GetOutputPort())
+
+ actor = vtk.vtkActor()
+ actor.SetMapper(mapper)
+ actor.GetProperty().SetColor(1.0, 1.0, 1.0)
+ actor.GetProperty().SetRepresentationToWireframe()
+
+ self.render.AddActor(actor)
+
+ def add_item(self, item_size: List[int], item_pos: List[int], dir: str="") -> None:
+
+ item = vtk.vtkCubeSource()
+ item.SetXLength(item_size[0])
+ item.SetYLength(item_size[1])
+ item.SetZLength(item_size[2])
+ item.SetCenter([
+ -0.5 * self.container_size[0] + 0.5 * item_size[0] + item_pos[0],
+ -0.5 * self.container_size[1] + 0.5 * item_size[1] + item_pos[1],
+ -0.5 * self.container_size[2] + 0.5 * item_size[2] + item_pos[2]
+ ])
+
+ mapper = vtk.vtkPolyDataMapper()
+ mapper.SetInputConnection(item.GetOutputPort())
+
+ colors = vtk.vtkNamedColors()
+ color_0 = color_key[self.item_idx % len(color_key)]
+ color_1 = int(self.item_idx / len(color_key))
+
+ actor = vtk.vtkActor()
+ actor.SetMapper(mapper)
+ actor.GetProperty().SetColor(colors.GetColor3d("red"))
+ actor.GetProperty().EdgeVisibilityOn()
+
+ self.render.AddActor(actor)
+ time.sleep(0.5)
+ self.render_window.Render()
+
+ time.sleep(0.3)
+ actor.GetProperty().SetColor(colors.GetColor3d(vtk_color[color_0][color_1]))
+ self.render_window.Render()
+
+ self.item_idx += 1
+
+ if not self.auto_render:
+ self.hold_on()
+
+ def hold_on(self) -> None:
+ self.interactor.Start()
+
+ def save_img(self) -> None:
+ time_str = datetime.datetime.now().strftime("%Y.%m.%d-%H-%M-%S.%f")
+ img_name = time_str + r".png"
+ path = os.path.join("images", "tmp")
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ window_to_image_filter = vtk.vtkWindowToImageFilter()
+ window_to_image_filter.SetInput(self.render_window)
+ window_to_image_filter.Update()
+
+ writer = vtk.vtkPNGWriter()
+ writer.SetFileName(os.path.join(path, img_name))
+ writer.SetInputConnection(window_to_image_filter.GetOutputPort())
+ writer.Write()
+
+
+if __name__ == "__main__":
+ render = VTKRender([10, 10, 10])
+
+ render.add_item([2, 3, 2], [0, 0, 0])
+ render.hold_on()
+ render.add_item([1, 1, 1], [2, 0, 0])
+
+ render.hold_on()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5c47ab7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+vtk==9.0.2
+scipy==1.11.3
+matplotlib==3.8.0
+gymnasium==0.29.1
+numpy>1.16.0
+tensorboard>=2.5.0
+numba>=0.51.0
+h5py>=2.10.0
+tqdm
+omegaconf
\ No newline at end of file
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
new file mode 100644
index 0000000..e039d6e
--- /dev/null
+++ b/tianshou/__init__.py
@@ -0,0 +1,12 @@
+from tianshou import data, env, exploration, policy, trainer, utils
+
+__version__ = "0.5.0"
+
+__all__ = [
+ "env",
+ "data",
+ "utils",
+ "policy",
+ "trainer",
+ "exploration",
+]
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
new file mode 100644
index 0000000..7a86ce8
--- /dev/null
+++ b/tianshou/data/__init__.py
@@ -0,0 +1,41 @@
+"""Data package."""
+# isort:skip_file
+
+from tianshou.data.batch import Batch
+from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
+from tianshou.data.utils.segtree import SegmentTree
+from tianshou.data.buffer.base import ReplayBuffer
+from tianshou.data.buffer.prio import PrioritizedReplayBuffer
+from tianshou.data.buffer.her import HERReplayBuffer
+from tianshou.data.buffer.manager import (
+ ReplayBufferManager,
+ PrioritizedReplayBufferManager,
+ HERReplayBufferManager,
+)
+from tianshou.data.buffer.vecbuf import (
+ HERVectorReplayBuffer,
+ PrioritizedVectorReplayBuffer,
+ VectorReplayBuffer,
+)
+from tianshou.data.buffer.cached import CachedReplayBuffer
+from tianshou.data.collector import Collector, AsyncCollector
+
+__all__ = [
+ "Batch",
+ "to_numpy",
+ "to_torch",
+ "to_torch_as",
+ "SegmentTree",
+ "ReplayBuffer",
+ "PrioritizedReplayBuffer",
+ "HERReplayBuffer",
+ "ReplayBufferManager",
+ "PrioritizedReplayBufferManager",
+ "HERReplayBufferManager",
+ "VectorReplayBuffer",
+ "PrioritizedVectorReplayBuffer",
+ "HERVectorReplayBuffer",
+ "CachedReplayBuffer",
+ "Collector",
+ "AsyncCollector",
+]
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
new file mode 100644
index 0000000..e204151
--- /dev/null
+++ b/tianshou/data/batch.py
@@ -0,0 +1,767 @@
+import pprint
+import warnings
+from collections.abc import Collection
+from copy import deepcopy
+from numbers import Number
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+
+IndexType = Union[slice, int, np.ndarray, List[int]]
+
+
+def _is_batch_set(obj: Any) -> bool:
+ # Batch set is a list/tuple of dict/Batch objects,
+ # or 1-D np.ndarray with object type,
+ # where each element is a dict/Batch object
+ if isinstance(obj, np.ndarray): # most often case
+ # "for element in obj" will just unpack the first dimension,
+ # but obj.tolist() will flatten ndarray of objects
+ # so do not use obj.tolist()
+ if obj.shape == ():
+ return False
+ return obj.dtype == object and \
+ all(isinstance(element, (dict, Batch)) for element in obj)
+ elif isinstance(obj, (list, tuple)):
+ if len(obj) > 0 and all(isinstance(element, (dict, Batch)) for element in obj):
+ return True
+ return False
+
+
+def _is_scalar(value: Any) -> bool:
+ # check if the value is a scalar
+ # 1. python bool object, number object: isinstance(value, Number)
+ # 2. numpy scalar: isinstance(value, np.generic)
+ # 3. python object rather than dict / Batch / tensor
+ # the check of dict / Batch is omitted because this only checks a value.
+ # a dict / Batch will eventually check their values
+ if isinstance(value, torch.Tensor):
+ return value.numel() == 1 and not value.shape
+ else:
+ # np.asanyarray will cause dead loop in some cases
+ return np.isscalar(value)
+
+
+def _is_number(value: Any) -> bool:
+ # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc.
+ # isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc.
+ # isinstance(value, np.bool_) checks np.bool_(True), etc.
+ # similar to np.isscalar but np.isscalar('st') returns True
+ return isinstance(value, (Number, np.number, np.bool_))
+
+
+def _to_array_with_correct_type(obj: Any) -> np.ndarray:
+ if isinstance(obj, np.ndarray) and \
+ issubclass(obj.dtype.type, (np.bool_, np.number)):
+ return obj # most often case
+ # convert the value to np.ndarray
+ # convert to object obj type if neither bool nor number
+ # raises an exception if array's elements are tensors themselves
+ obj_array = np.asanyarray(obj)
+ if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
+ obj_array = obj_array.astype(object)
+ if obj_array.dtype == object:
+ # scalar ndarray with object obj type is very annoying
+ # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
+ # a is not array([{}, {}], dtype=object), and a[0]={} results in
+ # something very strange:
+ # array([{}, array({}, dtype=object)], dtype=object)
+ if not obj_array.shape:
+ obj_array = obj_array.item(0)
+ elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
+ return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]])
+ elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
+ raise ValueError("Numpy arrays of tensors are not supported yet.")
+ return obj_array
+
+
+def _create_value(
+ inst: Any,
+ size: int,
+ stack: bool = True,
+) -> Union["Batch", np.ndarray, torch.Tensor]:
+ """Create empty place-holders accroding to inst's shape.
+
+ :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of
+ (3, 5), size = 10, stack=True returns an np.ndarry with shape of (10, 3, 5),
+ otherwise (10, 5)
+ """
+ has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
+ is_scalar = _is_scalar(inst)
+ if not stack and is_scalar:
+ # should never hit since it has already checked in Batch.cat_ , here we do not
+ # consider scalar types, following the behavior of numpy which does not support
+ # concatenation of zero-dimensional arrays (scalars)
+ raise TypeError(f"cannot concatenate with {inst} which is scalar")
+ if has_shape:
+ shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
+ if isinstance(inst, np.ndarray):
+ target_type = inst.dtype.type if issubclass(
+ inst.dtype.type, (np.bool_, np.number)
+ ) else object
+ return np.full(
+ shape, fill_value=None if target_type == object else 0, dtype=target_type
+ )
+ elif isinstance(inst, torch.Tensor):
+ return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype)
+ elif isinstance(inst, (dict, Batch)):
+ zero_batch = Batch()
+ for key, val in inst.items():
+ zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
+ return zero_batch
+ elif is_scalar:
+ return _create_value(np.asarray(inst), size, stack=stack)
+ else: # fall back to object
+ return np.array([None for _ in range(size)], object)
+
+
+def _assert_type_keys(keys: Iterable[str]) -> None:
+ assert all(isinstance(key, str) for key in keys), \
+ f"keys should all be string, but got {keys}"
+
+
+def _parse_value(obj: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
+ if isinstance(obj, Batch): # most often case
+ return obj
+ elif (isinstance(obj, np.ndarray) and
+ issubclass(obj.dtype.type, (np.bool_, np.number))) or \
+ isinstance(obj, torch.Tensor) or obj is None: # third often case
+ return obj
+ elif _is_number(obj): # second often case, but it is more time-consuming
+ return np.asanyarray(obj)
+ elif isinstance(obj, dict):
+ return Batch(obj)
+ else:
+ if not isinstance(obj, np.ndarray) and \
+ isinstance(obj, Collection) and len(obj) > 0 and \
+ all(isinstance(element, torch.Tensor) for element in obj):
+ try:
+ return torch.stack(obj) # type: ignore
+ except RuntimeError as exception:
+ raise TypeError(
+ "Batch does not support non-stackable iterable"
+ " of torch.Tensor as unique value yet."
+ ) from exception
+ if _is_batch_set(obj):
+ obj = Batch(obj) # list of dict / Batch
+ else:
+ # None, scalar, normal obj list (main case)
+ # or an actual list of objects
+ try:
+ obj = _to_array_with_correct_type(obj)
+ except ValueError as exception:
+ raise TypeError(
+ "Batch does not support heterogeneous list/"
+ "tuple of tensors as unique value yet."
+ ) from exception
+ return obj
+
+
+def _alloc_by_keys_diff(
+ meta: "Batch", batch: "Batch", size: int, stack: bool = True
+) -> None:
+ for key in batch.keys():
+ if key in meta.keys():
+ if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
+ _alloc_by_keys_diff(meta[key], batch[key], size, stack)
+ elif isinstance(meta[key], Batch) and meta[key].is_empty():
+ meta[key] = _create_value(batch[key], size, stack)
+ else:
+ meta[key] = _create_value(batch[key], size, stack)
+
+
+class Batch:
+ """The internal data structure in Tianshou.
+
+ Batch is a kind of supercharged array (of temporal data) stored individually in a
+ (recursive) dictionary of object that can be either numpy array, torch tensor, or
+ batch themselves. It is designed to make it extremely easily to access, manipulate
+ and set partial view of the heterogeneous data conveniently.
+
+ For a detailed description, please refer to :ref:`batch_concept`.
+ """
+
+ def __init__(
+ self,
+ batch_dict: Optional[Union[dict, "Batch", Sequence[Union[dict, "Batch"]],
+ np.ndarray]] = None,
+ copy: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ if copy:
+ batch_dict = deepcopy(batch_dict)
+ if batch_dict is not None:
+ if isinstance(batch_dict, (dict, Batch)):
+ _assert_type_keys(batch_dict.keys())
+ for batch_key, obj in batch_dict.items():
+ self.__dict__[batch_key] = _parse_value(obj)
+ elif _is_batch_set(batch_dict):
+ self.stack_(batch_dict) # type: ignore
+ if len(kwargs) > 0:
+ self.__init__(kwargs, copy=copy) # type: ignore
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ """Set self.key = value."""
+ self.__dict__[key] = _parse_value(value)
+
+ def __getattr__(self, key: str) -> Any:
+ """Return self.key. The "Any" return type is needed for mypy."""
+ return getattr(self.__dict__, key)
+
+ def __contains__(self, key: str) -> bool:
+ """Return key in self."""
+ return key in self.__dict__
+
+ def __getstate__(self) -> Dict[str, Any]:
+ """Pickling interface.
+
+ Only the actual data are serialized for both efficiency and simplicity.
+ """
+ state = {}
+ for batch_key, obj in self.items():
+ if isinstance(obj, Batch):
+ obj = obj.__getstate__()
+ state[batch_key] = obj
+ return state
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """Unpickling interface.
+
+ At this point, self is an empty Batch instance that has not been
+ initialized, so it can safely be initialized by the pickle state.
+ """
+ self.__init__(**state) # type: ignore
+
+ def __getitem__(self, index: Union[str, IndexType]) -> Any:
+ """Return self[index]."""
+ if isinstance(index, str):
+ return self.__dict__[index]
+ batch_items = self.items()
+ if len(batch_items) > 0:
+ new_batch = Batch()
+ for batch_key, obj in batch_items:
+ if isinstance(obj, Batch) and obj.is_empty():
+ new_batch.__dict__[batch_key] = Batch()
+ else:
+ new_batch.__dict__[batch_key] = obj[index]
+ return new_batch
+ else:
+ raise IndexError("Cannot access item from empty Batch object.")
+
+ def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
+ """Assign value to self[index]."""
+ value = _parse_value(value)
+ if isinstance(index, str):
+ self.__dict__[index] = value
+ return
+ if not isinstance(value, Batch):
+ raise ValueError(
+ "Batch does not supported tensor assignment. "
+ "Use a compatible Batch or dict instead."
+ )
+ if not set(value.keys()).issubset(self.__dict__.keys()):
+ raise ValueError("Creating keys is not supported by item assignment.")
+ for key, val in self.items():
+ try:
+ self.__dict__[key][index] = value[key]
+ except KeyError:
+ if isinstance(val, Batch):
+ self.__dict__[key][index] = Batch()
+ elif isinstance(val, torch.Tensor) or \
+ (isinstance(val, np.ndarray) and
+ issubclass(val.dtype.type, (np.bool_, np.number))):
+ self.__dict__[key][index] = 0
+ else:
+ self.__dict__[key][index] = None
+
+ def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
+ """Algebraic addition with another Batch instance in-place."""
+ if isinstance(other, Batch):
+ for (batch_key, obj), value in zip(
+ self.__dict__.items(), other.__dict__.values()
+ ): # TODO are keys consistent?
+ if isinstance(obj, Batch) and obj.is_empty():
+ continue
+ else:
+ self.__dict__[batch_key] += value
+ return self
+ elif _is_number(other):
+ for batch_key, obj in self.items():
+ if isinstance(obj, Batch) and obj.is_empty():
+ continue
+ else:
+ self.__dict__[batch_key] += other
+ return self
+ else:
+ raise TypeError("Only addition of Batch or number is supported.")
+
+ def __add__(self, other: Union["Batch", Number, np.number]) -> "Batch":
+ """Algebraic addition with another Batch instance out-of-place."""
+ return deepcopy(self).__iadd__(other)
+
+ def __imul__(self, value: Union[Number, np.number]) -> "Batch":
+ """Algebraic multiplication with a scalar value in-place."""
+ assert _is_number(value), "Only multiplication by a number is supported."
+ for batch_key, obj in self.__dict__.items():
+ if isinstance(obj, Batch) and obj.is_empty():
+ continue
+ self.__dict__[batch_key] *= value
+ return self
+
+ def __mul__(self, value: Union[Number, np.number]) -> "Batch":
+ """Algebraic multiplication with a scalar value out-of-place."""
+ return deepcopy(self).__imul__(value)
+
+ def __itruediv__(self, value: Union[Number, np.number]) -> "Batch":
+ """Algebraic division with a scalar value in-place."""
+ assert _is_number(value), "Only division by a number is supported."
+ for batch_key, obj in self.__dict__.items():
+ if isinstance(obj, Batch) and obj.is_empty():
+ continue
+ self.__dict__[batch_key] /= value
+ return self
+
+ def __truediv__(self, value: Union[Number, np.number]) -> "Batch":
+ """Algebraic division with a scalar value out-of-place."""
+ return deepcopy(self).__itruediv__(value)
+
+ def __repr__(self) -> str:
+ """Return str(self)."""
+ self_str = self.__class__.__name__ + "(\n"
+ flag = False
+ for batch_key, obj in self.__dict__.items():
+ rpl = "\n" + " " * (6 + len(batch_key))
+ obj_name = pprint.pformat(obj).replace("\n", rpl)
+ self_str += f" {batch_key}: {obj_name},\n"
+ flag = True
+ if flag:
+ self_str += ")"
+ else:
+ self_str = self.__class__.__name__ + "()"
+ return self_str
+
+ def to_numpy(self) -> None:
+ """Change all torch.Tensor to numpy.ndarray in-place."""
+ for batch_key, obj in self.items():
+ if isinstance(obj, torch.Tensor):
+ self.__dict__[batch_key] = obj.detach().cpu().numpy()
+ elif isinstance(obj, Batch):
+ obj.to_numpy()
+
+ def to_torch(
+ self,
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ """Change all numpy.ndarray to torch.Tensor in-place."""
+ if not isinstance(device, torch.device):
+ device = torch.device(device)
+
+ for batch_key, obj in self.items():
+ if isinstance(obj, torch.Tensor):
+ if dtype is not None and obj.dtype != dtype or \
+ obj.device.type != device.type or \
+ device.index != obj.device.index:
+ if dtype is not None:
+ obj = obj.type(dtype)
+ self.__dict__[batch_key] = obj.to(device)
+ elif isinstance(obj, Batch):
+ obj.to_torch(dtype, device)
+ else:
+ # ndarray or scalar
+ if not isinstance(obj, np.ndarray):
+ obj = np.asanyarray(obj)
+ obj = torch.from_numpy(obj).to(device)
+ if dtype is not None:
+ obj = obj.type(dtype)
+ self.__dict__[batch_key] = obj
+
+ def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None:
+ """Private method for Batch.cat_.
+
+ ::
+
+ >>> a = Batch(a=np.random.randn(3, 4))
+ >>> x = Batch(a=a, b=np.random.randn(4, 4))
+ >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4))
+
+ If we want to concatenate x and y, we want to pad y.a.a with zeros.
+ Without ``lens`` as a hint, when we concatenate x.a and y.a, we would
+ not be able to know how to pad y.a. So ``Batch.cat_`` should compute
+ the ``lens`` to give ``Batch.__cat`` a hint.
+ ::
+
+ >>> ans = Batch.cat([x, y])
+ >>> # this is equivalent to the following line
+ >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4])
+ >>> # this lens is equal to [len(a), len(b)]
+ """
+ # partial keys will be padded by zeros
+ # with the shape of [len, rest_shape]
+ sum_lens = [0]
+ for len_ in lens:
+ sum_lens.append(sum_lens[-1] + len_)
+ # collect non-empty keys
+ keys_map = [
+ set(
+ batch_key for batch_key, obj in batch.items()
+ if not (isinstance(obj, Batch) and obj.is_empty())
+ ) for batch in batches
+ ]
+ keys_shared = set.intersection(*keys_map)
+ values_shared = [[batch[key] for batch in batches] for key in keys_shared]
+ for key, shared_value in zip(keys_shared, values_shared):
+ if all(isinstance(element, (dict, Batch)) for element in shared_value):
+ batch_holder = Batch()
+ batch_holder.__cat(shared_value, lens=lens)
+ self.__dict__[key] = batch_holder
+ elif all(isinstance(element, torch.Tensor) for element in shared_value):
+ self.__dict__[key] = torch.cat(shared_value)
+ else:
+ # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch()))
+ # will fail here
+ shared_value = np.concatenate(shared_value)
+ self.__dict__[key] = _to_array_with_correct_type(shared_value)
+ keys_total = set.union(*[set(batch.keys()) for batch in batches])
+ keys_reserve_or_partial = set.difference(keys_total, keys_shared)
+ # keys that are reserved in all batches
+ keys_reserve = set.difference(keys_total, set.union(*keys_map))
+ # keys that occur only in some batches, but not all
+ keys_partial = keys_reserve_or_partial.difference(keys_reserve)
+ for key in keys_reserve:
+ # reserved keys
+ self.__dict__[key] = Batch()
+ for key in keys_partial:
+ for i, batch in enumerate(batches):
+ if key not in batch.__dict__:
+ continue
+ value = batch.get(key)
+ if isinstance(value, Batch) and value.is_empty():
+ continue
+ try:
+ self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
+ except KeyError:
+ self.__dict__[key] = \
+ _create_value(value, sum_lens[-1], stack=False)
+ self.__dict__[key][sum_lens[i]:sum_lens[i + 1]] = value
+
+ def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
+ """Concatenate a list of (or one) Batch objects into current batch."""
+ if isinstance(batches, Batch):
+ batches = [batches]
+ # check input format
+ batch_list = []
+ for batch in batches:
+ if isinstance(batch, dict):
+ if len(batch) > 0:
+ batch_list.append(Batch(batch))
+ elif isinstance(batch, Batch):
+ # x.is_empty() means that x is Batch() and should be ignored
+ if not batch.is_empty():
+ batch_list.append(batch)
+ else:
+ raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
+ if len(batch_list) == 0:
+ return
+ batches = batch_list
+ try:
+ # x.is_empty(recurse=True) here means x is a nested empty batch
+ # like Batch(a=Batch), and we have to treat it as length zero and
+ # keep it.
+ lens = [
+ 0 if batch.is_empty(recurse=True) else len(batch) for batch in batches
+ ]
+ except TypeError as exception:
+ raise ValueError(
+ "Batch.cat_ meets an exception. Maybe because there is any "
+ f"scalar in {batches} but Batch.cat_ does not support the "
+ "concatenation of scalar."
+ ) from exception
+ if not self.is_empty():
+ batches = [self] + list(batches)
+ lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
+ self.__cat(batches, lens)
+
+ @staticmethod
+ def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch":
+ """Concatenate a list of Batch object into a single new batch.
+
+ For keys that are not shared across all batches, batches that do not
+ have these keys will be padded by zeros with appropriate shapes. E.g.
+ ::
+
+ >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
+ >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
+ >>> c = Batch.cat([a, b])
+ >>> c.a.shape
+ (7, 4)
+ >>> c.b.shape
+ (7, 3)
+ >>> c.common.c.shape
+ (7, 5)
+ """
+ batch = Batch()
+ batch.cat_(batches)
+ return batch
+
+ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None:
+ """Stack a list of Batch object into current batch."""
+ # check input format
+ batch_list = []
+ for batch in batches:
+ if isinstance(batch, dict):
+ if len(batch) > 0:
+ batch_list.append(Batch(batch))
+ elif isinstance(batch, Batch):
+ # x.is_empty() means that x is Batch() and should be ignored
+ if not batch.is_empty():
+ batch_list.append(batch)
+ else:
+ raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
+ if len(batch_list) == 0:
+ return
+ batches = batch_list
+ if not self.is_empty():
+ batches = [self] + batches
+ # collect non-empty keys
+ keys_map = [
+ set(
+ batch_key for batch_key, obj in batch.items()
+ if not (isinstance(obj, Batch) and obj.is_empty())
+ ) for batch in batches
+ ]
+ keys_shared = set.intersection(*keys_map)
+ values_shared = [[batch[key] for batch in batches] for key in keys_shared]
+ for shared_key, value in zip(keys_shared, values_shared):
+ # second often
+ if all(isinstance(element, torch.Tensor) for element in value):
+ self.__dict__[shared_key] = torch.stack(value, axis)
+ # third often
+ elif all(isinstance(element, (Batch, dict)) for element in value):
+ self.__dict__[shared_key] = Batch.stack(value, axis)
+ else: # most often case is np.ndarray
+ try:
+ self.__dict__[shared_key] = \
+ _to_array_with_correct_type(np.stack(value, axis))
+ except ValueError:
+ warnings.warn(
+ "You are using tensors with different shape,"
+ " fallback to dtype=object by default."
+ )
+ self.__dict__[shared_key] = np.array(value, dtype=object)
+ # all the keys
+ keys_total = set.union(*[set(batch.keys()) for batch in batches])
+ # keys that are reserved in all batches
+ keys_reserve = set.difference(keys_total, set.union(*keys_map))
+ # keys that are either partial or reserved
+ keys_reserve_or_partial = set.difference(keys_total, keys_shared)
+ # keys that occur only in some batches, but not all
+ keys_partial = keys_reserve_or_partial.difference(keys_reserve)
+ if keys_partial and axis != 0:
+ raise ValueError(
+ f"Stack of Batch with non-shared keys {keys_partial} is only "
+ f"supported with axis=0, but got axis={axis}!"
+ )
+ for key in keys_reserve:
+ # reserved keys
+ self.__dict__[key] = Batch()
+ for key in keys_partial:
+ for i, batch in enumerate(batches):
+ if key not in batch.__dict__:
+ continue
+ value = batch.get(key)
+ if isinstance(value, Batch) and value.is_empty(): # type: ignore
+ continue # type: ignore
+ try:
+ self.__dict__[key][i] = value
+ except KeyError:
+ self.__dict__[key] = _create_value(value, len(batches))
+ self.__dict__[key][i] = value
+
+ @staticmethod
+ def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch":
+ """Stack a list of Batch object into a single new batch.
+
+ For keys that are not shared across all batches, batches that do not
+ have these keys will be padded by zeros. E.g.
+ ::
+
+ >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
+ >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
+ >>> c = Batch.stack([a, b])
+ >>> c.a.shape
+ (2, 4, 4)
+ >>> c.b.shape
+ (2, 4, 6)
+ >>> c.common.c.shape
+ (2, 4, 5)
+
+ .. note::
+
+ If there are keys that are not shared across all batches, ``stack``
+ with ``axis != 0`` is undefined, and will cause an exception.
+ """
+ batch = Batch()
+ batch.stack_(batches, axis)
+ return batch
+
+ def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch":
+ """Return an empty Batch object with 0 or None filled.
+
+ If "index" is specified, it will only reset the specific indexed-data.
+ ::
+
+ >>> data.empty_()
+ >>> print(data)
+ Batch(
+ a: array([[0., 0.],
+ [0., 0.]]),
+ b: array([None, None], dtype=object),
+ )
+ >>> b={'c': [2., 'st'], 'd': [1., 0.]}
+ >>> data = Batch(a=[False, True], b=b)
+ >>> data[0] = Batch.empty(data[1])
+ >>> data
+ Batch(
+ a: array([False, True]),
+ b: Batch(
+ c: array([None, 'st']),
+ d: array([0., 0.]),
+ ),
+ )
+ """
+ for batch_key, obj in self.items():
+ if isinstance(obj, torch.Tensor): # most often case
+ self.__dict__[batch_key][index] = 0
+ elif obj is None:
+ continue
+ elif isinstance(obj, np.ndarray):
+ if obj.dtype == object:
+ self.__dict__[batch_key][index] = None
+ else:
+ self.__dict__[batch_key][index] = 0
+ elif isinstance(obj, Batch):
+ self.__dict__[batch_key].empty_(index=index)
+ else: # scalar value
+ warnings.warn(
+ "You are calling Batch.empty on a NumPy scalar, "
+ "which may cause undefined behaviors."
+ )
+ if _is_number(obj):
+ self.__dict__[batch_key] = obj.__class__(0)
+ else:
+ self.__dict__[batch_key] = None
+ return self
+
+ @staticmethod
+ def empty(batch: "Batch", index: Optional[IndexType] = None) -> "Batch":
+ """Return an empty Batch object with 0 or None filled.
+
+ The shape is the same as the given Batch.
+ """
+ return deepcopy(batch).empty_(index)
+
+ def update(
+ self, batch: Optional[Union[dict, "Batch"]] = None, **kwargs: Any
+ ) -> None:
+ """Update this batch from another dict/Batch."""
+ if batch is None:
+ self.update(kwargs)
+ return
+ for batch_key, obj in batch.items():
+ self.__dict__[batch_key] = _parse_value(obj)
+ if kwargs:
+ self.update(kwargs)
+
+ def __len__(self) -> int:
+ """Return len(self)."""
+ lens = []
+ for obj in self.__dict__.values():
+ if isinstance(obj, Batch) and obj.is_empty(recurse=True):
+ continue
+ elif hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
+ lens.append(len(obj))
+ else:
+ raise TypeError(f"Object {obj} in {self} has no len()")
+ if len(lens) == 0:
+ # empty batch has the shape of any, like the tensorflow '?' shape.
+ # So it has no length.
+ raise TypeError(f"Object {self} has no len()")
+ return min(lens)
+
+ def is_empty(self, recurse: bool = False) -> bool:
+ """Test if a Batch is empty.
+
+ If ``recurse=True``, it further tests the values of the object; else
+ it only tests the existence of any key.
+
+ ``b.is_empty(recurse=True)`` is mainly used to distinguish
+ ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
+ exceptions when applied to ``len()``, but the former can be used in
+ ``cat``, while the latter is a scalar and cannot be used in ``cat``.
+
+ Another usage is in ``__len__``, where we have to skip checking the
+ length of recursively empty Batch.
+ ::
+
+ >>> Batch().is_empty()
+ True
+ >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
+ False
+ >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
+ True
+ >>> Batch(d=1).is_empty()
+ False
+ >>> Batch(a=np.float64(1.0)).is_empty()
+ False
+ """
+ if len(self.__dict__) == 0:
+ return True
+ if not recurse:
+ return False
+ return all(
+ False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
+ for obj in self.values()
+ )
+
+ @property
+ def shape(self) -> List[int]:
+ """Return self.shape."""
+ if self.is_empty():
+ return []
+ else:
+ data_shape = []
+ for obj in self.__dict__.values():
+ try:
+ data_shape.append(list(obj.shape))
+ except AttributeError:
+ data_shape.append([])
+ return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
+ else data_shape[0]
+
+ def split(self,
+ size: int,
+ shuffle: bool = True,
+ merge_last: bool = False) -> Iterator["Batch"]:
+ """Split whole data into multiple small batches.
+
+ :param int size: divide the data batch with the given size, but one
+ batch if the length of the batch is smaller than "size".
+ :param bool shuffle: randomly shuffle the entire data batch if it is
+ True, otherwise remain in the same. Default to True.
+ :param bool merge_last: merge the last batch into the previous one.
+ Default to False.
+ """
+ length = len(self)
+ assert 1 <= size # size can be greater than length, return whole batch
+ if shuffle:
+ indices = np.random.permutation(length)
+ else:
+ indices = np.arange(length)
+ merge_last = merge_last and length % size > 0
+ for idx in range(0, length, size):
+ if merge_last and idx + size + size >= length:
+ yield self[indices[idx:]]
+ break
+ yield self[indices[idx:idx + size]]
diff --git a/tianshou/data/buffer/__init__.py b/tianshou/data/buffer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
new file mode 100644
index 0000000..aa159bd
--- /dev/null
+++ b/tianshou/data/buffer/base.py
@@ -0,0 +1,389 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import h5py
+import numpy as np
+
+from tianshou.data import Batch
+from tianshou.data.batch import _alloc_by_keys_diff, _create_value
+from tianshou.data.utils.converter import from_hdf5, to_hdf5
+
+
+class ReplayBuffer:
+ """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction \
+ between the policy and environment.
+
+ ReplayBuffer can be considered as a specialized form (or management) of Batch. It
+ stores all the data in a batch with circular-queue style.
+
+ For the example usage of ReplayBuffer, please check out Section Buffer in
+ :doc:`/tutorials/concepts`.
+
+ :param int size: the maximum size of replay buffer.
+ :param int stack_num: the frame-stack sampling argument, should be greater than or
+ equal to 1. Default to 1 (no stacking).
+ :param bool ignore_obs_next: whether to store obs_next. Default to False.
+ :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape
+ of (timestep, ...) because of temporal stacking. Default to False.
+ :param bool sample_avail: the parameter indicating sampling only available index
+ when using frame-stack sampling method. Default to False.
+ """
+
+ _reserved_keys = (
+ "obs", "act", "rew", "terminated", "truncated", "done", "obs_next", "info",
+ "policy"
+ )
+ _input_keys = (
+ "obs", "act", "rew", "terminated", "truncated", "obs_next", "info", "policy"
+ )
+
+ def __init__(
+ self,
+ size: int,
+ stack_num: int = 1,
+ ignore_obs_next: bool = False,
+ save_only_last_obs: bool = False,
+ sample_avail: bool = False,
+ **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError
+ ) -> None:
+ self.options: Dict[str, Any] = {
+ "stack_num": stack_num,
+ "ignore_obs_next": ignore_obs_next,
+ "save_only_last_obs": save_only_last_obs,
+ "sample_avail": sample_avail,
+ }
+ super().__init__()
+ self.maxsize = int(size)
+ assert stack_num > 0, "stack_num should be greater than 0"
+ self.stack_num = stack_num
+ self._indices = np.arange(size)
+ self._save_obs_next = not ignore_obs_next
+ self._save_only_last_obs = save_only_last_obs
+ self._sample_avail = sample_avail
+ self._meta: Batch = Batch()
+ self._ep_rew: Union[float, np.ndarray]
+ self.reset()
+
+ def __len__(self) -> int:
+ """Return len(self)."""
+ return self._size
+
+ def __repr__(self) -> str:
+ """Return str(self)."""
+ return self.__class__.__name__ + self._meta.__repr__()[5:]
+
+ def __getattr__(self, key: str) -> Any:
+ """Return self.key."""
+ try:
+ return self._meta[key]
+ except KeyError as exception:
+ raise AttributeError from exception
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """Unpickling interface.
+
+ We need it because pickling buffer does not work out-of-the-box
+ ("buffer.__getattr__" is customized).
+ """
+ self.__dict__.update(state)
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ """Set self.key = value."""
+ assert (key not in self._reserved_keys
+ ), "key '{}' is reserved and cannot be assigned".format(key)
+ super().__setattr__(key, value)
+
+ def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
+ """Save replay buffer to HDF5 file."""
+ with h5py.File(path, "w") as f:
+ to_hdf5(self.__dict__, f, compression=compression)
+
+ @classmethod
+ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
+ """Load replay buffer from HDF5 file."""
+ with h5py.File(path, "r") as f:
+ buf = cls.__new__(cls)
+ buf.__setstate__(from_hdf5(f, device=device)) # type: ignore
+ return buf
+
+ @classmethod
+ def from_data(
+ cls, obs: h5py.Dataset, act: h5py.Dataset, rew: h5py.Dataset,
+ terminated: h5py.Dataset, truncated: h5py.Dataset, done: h5py.Dataset,
+ obs_next: h5py.Dataset
+ ) -> "ReplayBuffer":
+ size = len(obs)
+ assert all(len(dset) == size for dset in [obs, act, rew, terminated,
+ truncated, done, obs_next]), \
+ "Lengths of all hdf5 datasets need to be equal."
+ buf = cls(size)
+ if size == 0:
+ return buf
+ batch = Batch(
+ obs=obs,
+ act=act,
+ rew=rew,
+ terminated=terminated,
+ truncated=truncated,
+ done=done,
+ obs_next=obs_next
+ )
+ buf.set_batch(batch)
+ buf._size = size
+ return buf
+
+ def reset(self, keep_statistics: bool = False) -> None:
+ """Clear all the data in replay buffer and episode statistics."""
+ self.last_index = np.array([0])
+ self._index = self._size = 0
+ if not keep_statistics:
+ self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
+
+ def set_batch(self, batch: Batch) -> None:
+ """Manually choose the batch you want the ReplayBuffer to manage."""
+ assert len(batch) == self.maxsize and set(batch.keys()).issubset(
+ self._reserved_keys
+ ), "Input batch doesn't meet ReplayBuffer's data form requirement."
+ self._meta = batch
+
+ def unfinished_index(self) -> np.ndarray:
+ """Return the index of unfinished episode."""
+ last = (self._index - 1) % self._size if self._size else 0
+ return np.array([last] if not self.done[last] and self._size else [], int)
+
+ def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
+ """Return the index of previous transition.
+
+ The index won't be modified if it is the beginning of an episode.
+ """
+ index = (index - 1) % self._size
+ end_flag = self.done[index] | (index == self.last_index[0])
+ return (index + end_flag) % self._size
+
+ def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
+ """Return the index of next transition.
+
+ The index won't be modified if it is the end of an episode.
+ """
+ end_flag = self.done[index] | (index == self.last_index[0])
+ return (index + (1 - end_flag)) % self._size
+
+ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
+ """Move the data from the given buffer to current buffer.
+
+ Return the updated indices. If update fails, return an empty array.
+ """
+ if len(buffer) == 0 or self.maxsize == 0:
+ return np.array([], int)
+ stack_num, buffer.stack_num = buffer.stack_num, 1
+ from_indices = buffer.sample_indices(0) # get all available indices
+ buffer.stack_num = stack_num
+ if len(from_indices) == 0:
+ return np.array([], int)
+ to_indices = []
+ for _ in range(len(from_indices)):
+ to_indices.append(self._index)
+ self.last_index[0] = self._index
+ self._index = (self._index + 1) % self.maxsize
+ self._size = min(self._size + 1, self.maxsize)
+ to_indices = np.array(to_indices)
+ if self._meta.is_empty():
+ self._meta = _create_value( # type: ignore
+ buffer._meta, self.maxsize, stack=False)
+ self._meta[to_indices] = buffer._meta[from_indices]
+ return to_indices
+
+ def _add_index(self, rew: Union[float, np.ndarray],
+ done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]:
+ """Maintain the buffer's state after adding one data batch.
+
+ Return (index_to_be_modified, episode_reward, episode_length,
+ episode_start_index).
+ """
+ self.last_index[0] = ptr = self._index
+ self._size = min(self._size + 1, self.maxsize)
+ self._index = (self._index + 1) % self.maxsize
+
+ self._ep_rew += rew
+ self._ep_len += 1
+
+ if done:
+ result = ptr, self._ep_rew, self._ep_len, self._ep_idx
+ self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index
+ return result
+ else:
+ return ptr, self._ep_rew * 0.0, 0, self._ep_idx
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Add a batch of data into replay buffer.
+
+ :param Batch batch: the input data batch. Its keys must belong to the 7
+ input keys, and "obs", "act", "rew", "terminated", "truncated" is
+ required.
+ :param buffer_ids: to make consistent with other buffer's add function; if it
+ is not None, we assume the input batch's first dimension is always 1.
+
+ Return (current_index, episode_reward, episode_length, episode_start_index). If
+ the episode is not finished, the return value of episode_length and
+ episode_reward is 0.
+ """
+ # preprocess batch
+ new_batch = Batch()
+ for key in set(self._input_keys).intersection(batch.keys()):
+ new_batch.__dict__[key] = batch[key]
+ batch = new_batch
+ batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
+ assert set(["obs", "act", "rew", "terminated", "truncated",
+ "done"]).issubset(batch.keys())
+ stacked_batch = buffer_ids is not None
+ if stacked_batch:
+ assert len(batch) == 1
+ if self._save_only_last_obs:
+ batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
+ if not self._save_obs_next:
+ batch.pop("obs_next", None)
+ elif self._save_only_last_obs:
+ batch.obs_next = (
+ batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
+ )
+ # get ptr
+ if stacked_batch:
+ rew, done = batch.rew[0], batch.done[0]
+ else:
+ rew, done = batch.rew, batch.done
+ ptr, ep_rew, ep_len, ep_idx = list(
+ map(lambda x: np.array([x]), self._add_index(rew, done))
+ )
+ try:
+ self._meta[ptr] = batch
+ except ValueError:
+ stack = not stacked_batch
+ batch.rew = batch.rew.astype(float)
+ batch.done = batch.done.astype(bool)
+ batch.terminated = batch.terminated.astype(bool)
+ batch.truncated = batch.truncated.astype(bool)
+ if self._meta.is_empty():
+ self._meta = _create_value( # type: ignore
+ batch, self.maxsize, stack)
+ else: # dynamic key pops up in batch
+ _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
+ self._meta[ptr] = batch
+ return ptr, ep_rew, ep_len, ep_idx
+
+ def sample_indices(self, batch_size: int) -> np.ndarray:
+ """Get a random sample of index with size = batch_size.
+
+ Return all available indices in the buffer if batch_size is 0; return an empty
+ numpy array if batch_size < 0 or no available index can be sampled.
+ """
+ if self.stack_num == 1 or not self._sample_avail: # most often case
+ if batch_size > 0:
+ return np.random.choice(self._size, batch_size)
+ elif batch_size == 0: # construct current available indices
+ return np.concatenate(
+ [np.arange(self._index, self._size),
+ np.arange(self._index)]
+ )
+ else:
+ return np.array([], int)
+ else:
+ if batch_size < 0:
+ return np.array([], int)
+ all_indices = prev_indices = np.concatenate(
+ [np.arange(self._index, self._size),
+ np.arange(self._index)]
+ )
+ for _ in range(self.stack_num - 2):
+ prev_indices = self.prev(prev_indices)
+ all_indices = all_indices[prev_indices != self.prev(prev_indices)]
+ if batch_size > 0:
+ return np.random.choice(all_indices, batch_size)
+ else:
+ return all_indices
+
+ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
+ """Get a random sample from buffer with size = batch_size.
+
+ Return all the data in the buffer if batch_size is 0.
+
+ :return: Sample data and its corresponding index inside the buffer.
+ """
+ indices = self.sample_indices(batch_size)
+ return self[indices], indices
+
+ def get(
+ self,
+ index: Union[int, List[int], np.ndarray],
+ key: str,
+ default_value: Any = None,
+ stack_num: Optional[int] = None,
+ ) -> Union[Batch, np.ndarray]:
+ """Return the stacked result.
+
+ E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the
+ stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.
+
+ :param index: the index for getting stacked data.
+ :param str key: the key to get, should be one of the reserved_keys.
+ :param default_value: if the given key's data is not found and default_value is
+ set, return this default_value.
+ :param int stack_num: Default to self.stack_num.
+ """
+ if key not in self._meta and default_value is not None:
+ return default_value
+ val = self._meta[key]
+ if stack_num is None:
+ stack_num = self.stack_num
+ try:
+ if stack_num == 1: # the most often case
+ return val[index]
+ stack: List[Any] = []
+ if isinstance(index, list):
+ indices = np.array(index)
+ else:
+ indices = index # type: ignore
+ for _ in range(stack_num):
+ stack = [val[indices]] + stack
+ indices = self.prev(indices)
+ if isinstance(val, Batch):
+ return Batch.stack(stack, axis=indices.ndim)
+ else:
+ return np.stack(stack, axis=indices.ndim)
+ except IndexError as exception:
+ if not (isinstance(val, Batch) and val.is_empty()):
+ raise exception # val != Batch()
+ return Batch()
+
+ def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
+ """Return a data batch: self[index].
+
+ If stack_num is larger than 1, return the stacked obs and obs_next with shape
+ (batch, len, ...).
+ """
+ if isinstance(index, slice): # change slice to np array
+ # buffer[:] will get all available data
+ indices = self.sample_indices(0) if index == slice(None) \
+ else self._indices[:len(self)][index]
+ else:
+ indices = index # type: ignore
+ # raise KeyError first instead of AttributeError,
+ # to support np.array([ReplayBuffer()])
+ obs = self.get(indices, "obs")
+ if self._save_obs_next:
+ obs_next = self.get(indices, "obs_next", Batch())
+ else:
+ obs_next = self.get(self.next(indices), "obs", Batch())
+ return Batch(
+ obs=obs,
+ act=self.act[indices],
+ rew=self.rew[indices],
+ terminated=self.terminated[indices],
+ truncated=self.truncated[indices],
+ done=self.done[indices],
+ obs_next=obs_next,
+ info=self.get(indices, "info", Batch()),
+ policy=self.get(indices, "policy", Batch()),
+ )
diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py
new file mode 100644
index 0000000..5063750
--- /dev/null
+++ b/tianshou/data/buffer/cached.py
@@ -0,0 +1,84 @@
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager
+
+
+class CachedReplayBuffer(ReplayBufferManager):
+ """CachedReplayBuffer contains a given main buffer and n cached buffers, \
+ ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``.
+
+ The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ...
+ | cached_buffers[cached_buffer_num - 1] |``.
+
+ The data is first stored in cached buffers. When an episode is terminated, the data
+ will move to the main buffer and the corresponding cached buffer will be reset.
+
+ :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function
+ behaves normally.
+ :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached
+ buffer.
+ :param int max_episode_length: the maximum length of one episode, used in each
+ cached buffer's maxsize.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(
+ self,
+ main_buffer: ReplayBuffer,
+ cached_buffer_num: int,
+ max_episode_length: int,
+ ) -> None:
+ assert cached_buffer_num > 0 and max_episode_length > 0
+ assert type(main_buffer) == ReplayBuffer
+ kwargs = main_buffer.options
+ buffers = [main_buffer] + [
+ ReplayBuffer(max_episode_length, **kwargs)
+ for _ in range(cached_buffer_num)
+ ]
+ super().__init__(buffer_list=buffers)
+ self.main_buffer = self.buffers[0]
+ self.cached_buffers = self.buffers[1:]
+ self.cached_buffer_num = cached_buffer_num
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Add a batch of data into CachedReplayBuffer.
+
+ Each of the data's length (first dimension) must equal to the length of
+ buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1].
+
+ Return (current_index, episode_reward, episode_length, episode_start_index)
+ with each of the shape (len(buffer_ids), ...), where (current_index[i],
+ episode_reward[i], episode_length[i], episode_start_index[i]) refers to the
+ cached_buffer_ids[i]th cached buffer's corresponding episode result.
+ """
+ if buffer_ids is None:
+ buf_arr = np.arange(1, 1 + self.cached_buffer_num)
+ else: # make sure it is np.ndarray
+ buf_arr = np.asarray(buffer_ids) + 1
+ ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr)
+ # find the terminated episode, move data from cached buf to main buf
+ updated_ptr, updated_ep_idx = [], []
+ done = np.logical_or(batch.terminated, batch.truncated)
+ for buffer_idx in buf_arr[done]:
+ index = self.main_buffer.update(self.buffers[buffer_idx])
+ if len(index) == 0: # unsuccessful move, replace with -1
+ index = [-1]
+ updated_ep_idx.append(index[0])
+ updated_ptr.append(index[-1])
+ self.buffers[buffer_idx].reset()
+ self._lengths[0] = len(self.main_buffer)
+ self._lengths[buffer_idx] = 0
+ self.last_index[0] = index[-1]
+ self.last_index[buffer_idx] = self._offset[buffer_idx]
+ ptr[done] = updated_ptr
+ ep_idx[done] = updated_ep_idx
+ return ptr, ep_rew, ep_len, ep_idx
diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py
new file mode 100644
index 0000000..fc18243
--- /dev/null
+++ b/tianshou/data/buffer/her.py
@@ -0,0 +1,187 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from tianshou.data import Batch, ReplayBuffer
+
+
+class HERReplayBuffer(ReplayBuffer):
+ """Implementation of Hindsight Experience Replay. arXiv:1707.01495.
+
+ HERReplayBuffer is to be used with goal-based environment where the
+ observation is a dictionary with keys ``observation``, ``achieved_goal`` and
+ ``desired_goal``. Currently support only HER's future strategy, online sampling.
+
+ :param int size: the size of the replay buffer.
+ :param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
+ ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
+ The two arguments are of shape (batch_size, ...original_shape) and the returned
+ rewards must be of shape (batch_size,).
+ :param int horizon: the maximum number of steps in an episode.
+ :param int future_k: the 'k' parameter introduced in the paper. In short, there
+ will be at most k episodes that are re-written for every 1 unaltered episode
+ during the sampling.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
+ horizon: int,
+ future_k: float = 8.0,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(size, **kwargs)
+ self.horizon = horizon
+ self.future_p = 1 - 1 / future_k
+ self.compute_reward_fn = compute_reward_fn
+ self._original_meta = Batch()
+ self._altered_indices = np.array([])
+
+ def _restore_cache(self) -> None:
+ """Write cached original meta back to `self._meta`.
+
+ It's called everytime before 'writing', 'sampling' or 'saving' the buffer.
+ """
+ if not hasattr(self, '_altered_indices'):
+ return
+
+ if self._altered_indices.size == 0:
+ return
+ self._meta[self._altered_indices] = self._original_meta
+ # Clean
+ self._original_meta = Batch()
+ self._altered_indices = np.array([])
+
+ def reset(self, keep_statistics: bool = False) -> None:
+ self._restore_cache()
+ return super().reset(keep_statistics)
+
+ def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
+ self._restore_cache()
+ return super().save_hdf5(path, compression)
+
+ def set_batch(self, batch: Batch) -> None:
+ self._restore_cache()
+ return super().set_batch(batch)
+
+ def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
+ self._restore_cache()
+ return super().update(buffer)
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ self._restore_cache()
+ return super().add(batch, buffer_ids)
+
+ def sample_indices(self, batch_size: int) -> np.ndarray:
+ """Get a random sample of index with size = batch_size.
+
+ Return all available indices in the buffer if batch_size is 0; return an \
+ empty numpy array if batch_size < 0 or no available index can be sampled. \
+ Additionally, some episodes of the sampled transitions will be re-written \
+ according to HER.
+ """
+ self._restore_cache()
+ indices = super().sample_indices(batch_size=batch_size)
+ self.rewrite_transitions(indices.copy())
+ return indices
+
+ def rewrite_transitions(self, indices: np.ndarray) -> None:
+ """Re-write the goal of some sampled transitions' episodes according to HER.
+
+ Currently applies only HER's 'future' strategy. The new goals will be written \
+ directly to the internal batch data temporarily and will be restored right \
+ before the next sampling or when using some of the buffer's method (e.g. \
+ `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \
+ calculation etc., performs correctly without additional alteration.
+ """
+ if indices.size == 0:
+ return
+
+ # Sort indices keeping chronological order
+ indices[indices < self._index] += self.maxsize
+ indices = np.sort(indices)
+ indices[indices >= self.maxsize] -= self.maxsize
+
+ # Construct episode trajectories
+ indices = [indices]
+ for _ in range(self.horizon - 1):
+ indices.append(self.next(indices[-1]))
+ indices = np.stack(indices)
+
+ # Calculate future timestep to use
+ current = indices[0]
+ terminal = indices[-1]
+ episodes_len = (terminal - current + self.maxsize) % self.maxsize
+ future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
+ future_offset = np.round(future_offset).astype(int)
+ future_t = (current + future_offset) % self.maxsize
+
+ # Compute indices
+ # open indices are used to find longest, unique trajectories among
+ # presented episodes
+ unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1])
+ unique_ep_indices = indices[:, unique_ep_open_indices]
+ # close indices are used to find max future_t among presented episodes
+ unique_ep_close_indices = np.hstack(
+ [(unique_ep_open_indices - 1)[1:],
+ len(terminal) - 1]
+ )
+ # episode indices that will be altered
+ her_ep_indices = np.random.choice(
+ len(unique_ep_open_indices),
+ size=int(len(unique_ep_open_indices) * self.future_p),
+ replace=False
+ )
+
+ # Cache original meta
+ self._altered_indices = unique_ep_indices.copy()
+ self._original_meta = self._meta[self._altered_indices].copy()
+
+ # Copy original obs, ep_rew (and obs_next), and obs of future time step
+ ep_obs = self[unique_ep_indices].obs
+ ep_rew = self[unique_ep_indices].rew
+ if self._save_obs_next:
+ ep_obs_next = self[unique_ep_indices].obs_next
+ future_obs = self[future_t[unique_ep_close_indices]].obs_next
+ else:
+ future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
+
+ # Re-assign goals and rewards via broadcast assignment
+ ep_obs.desired_goal[:, her_ep_indices] = \
+ future_obs.achieved_goal[None, her_ep_indices]
+ if self._save_obs_next:
+ ep_obs_next.desired_goal[:, her_ep_indices] = \
+ future_obs.achieved_goal[None, her_ep_indices]
+ ep_rew[:, her_ep_indices] = \
+ self._compute_reward(ep_obs_next)[:, her_ep_indices]
+ else:
+ tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
+ ep_rew[:, her_ep_indices] = \
+ self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
+
+ # Sanity check
+ assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape
+ assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape
+ assert ep_rew.shape == unique_ep_indices.shape
+
+ # Re-write meta
+ self._meta.obs[unique_ep_indices] = ep_obs
+ if self._save_obs_next:
+ self._meta.obs_next[unique_ep_indices] = ep_obs_next
+ self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)
+
+ def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray:
+ lead_shape = obs.observation.shape[:lead_dims]
+ g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:])
+ ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:])
+ rewards = self.compute_reward_fn(ag, g)
+ return rewards.reshape(*lead_shape, *rewards.shape[1:])
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
new file mode 100644
index 0000000..b694c1a
--- /dev/null
+++ b/tianshou/data/buffer/manager.py
@@ -0,0 +1,297 @@
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from numba import njit
+
+from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer
+from tianshou.data.batch import _alloc_by_keys_diff, _create_value
+
+
+class ReplayBufferManager(ReplayBuffer):
+ """ReplayBufferManager contains a list of ReplayBuffer with exactly the same \
+ configuration.
+
+ These replay buffers have contiguous memory layout, and the storage space each
+ buffer has is a shallow copy of the topmost memory.
+
+ :param buffer_list: a list of ReplayBuffer needed to be handled.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(
+ self, buffer_list: Union[List[ReplayBuffer], List[HERReplayBuffer]]
+ ) -> None:
+ self.buffer_num = len(buffer_list)
+ self.buffers = np.array(buffer_list, dtype=object)
+ offset, size = [], 0
+ buffer_type = type(self.buffers[0])
+ kwargs = self.buffers[0].options
+ for buf in self.buffers:
+ assert buf._meta.is_empty()
+ assert isinstance(buf, buffer_type) and buf.options == kwargs
+ offset.append(size)
+ size += buf.maxsize
+ self._offset = np.array(offset)
+ self._extend_offset = np.array(offset + [size])
+ self._lengths = np.zeros_like(offset)
+ super().__init__(size=size, **kwargs)
+ self._compile()
+ self._meta: Batch
+
+ def _compile(self) -> None:
+ lens = last = index = np.array([0])
+ offset = np.array([0, 1])
+ done = np.array([False, False])
+ _prev_index(index, offset, done, last, lens)
+ _next_index(index, offset, done, last, lens)
+
+ def __len__(self) -> int:
+ return int(self._lengths.sum())
+
+ def reset(self, keep_statistics: bool = False) -> None:
+ self.last_index = self._offset.copy()
+ self._lengths = np.zeros_like(self._offset)
+ for buf in self.buffers:
+ buf.reset(keep_statistics=keep_statistics)
+
+ def _set_batch_for_children(self) -> None:
+ for offset, buf in zip(self._offset, self.buffers):
+ buf.set_batch(self._meta[offset:offset + buf.maxsize])
+
+ def set_batch(self, batch: Batch) -> None:
+ super().set_batch(batch)
+ self._set_batch_for_children()
+
+ def unfinished_index(self) -> np.ndarray:
+ return np.concatenate(
+ [
+ buf.unfinished_index() + offset
+ for offset, buf in zip(self._offset, self.buffers)
+ ]
+ )
+
+ def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
+ if isinstance(index, (list, np.ndarray)):
+ return _prev_index(
+ np.asarray(index), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )
+ else:
+ return _prev_index(
+ np.array([index]), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )[0]
+
+ def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
+ if isinstance(index, (list, np.ndarray)):
+ return _next_index(
+ np.asarray(index), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )
+ else:
+ return _next_index(
+ np.array([index]), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )[0]
+
+ def update(self, buffer: ReplayBuffer) -> np.ndarray:
+ """The ReplayBufferManager cannot be updated by any buffer."""
+ raise NotImplementedError
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Add a batch of data into ReplayBufferManager.
+
+ Each of the data's length (first dimension) must equal to the length of
+ buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
+
+ Return (current_index, episode_reward, episode_length, episode_start_index). If
+ the episode is not finished, the return value of episode_length and
+ episode_reward is 0.
+ """
+ # preprocess batch
+ new_batch = Batch()
+ for key in set(self._reserved_keys).intersection(batch.keys()):
+ new_batch.__dict__[key] = batch[key]
+ batch = new_batch
+ batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
+ assert set(["obs", "act", "rew", "terminated", "truncated",
+ "done"]).issubset(batch.keys())
+ if self._save_only_last_obs:
+ batch.obs = batch.obs[:, -1]
+ if not self._save_obs_next:
+ batch.pop("obs_next", None)
+ elif self._save_only_last_obs:
+ batch.obs_next = batch.obs_next[:, -1]
+ # get index
+ if buffer_ids is None:
+ buffer_ids = np.arange(self.buffer_num)
+ ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
+ for batch_idx, buffer_id in enumerate(buffer_ids):
+ ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
+ batch.rew[batch_idx], batch.done[batch_idx]
+ )
+ ptrs.append(ptr + self._offset[buffer_id])
+ ep_lens.append(ep_len)
+ ep_rews.append(ep_rew)
+ ep_idxs.append(ep_idx + self._offset[buffer_id])
+ self.last_index[buffer_id] = ptr + self._offset[buffer_id]
+ self._lengths[buffer_id] = len(self.buffers[buffer_id])
+ ptrs = np.array(ptrs)
+ try:
+ self._meta[ptrs] = batch
+ except ValueError:
+ batch.rew = batch.rew.astype(float)
+ batch.done = batch.done.astype(bool)
+ batch.terminated = batch.terminated.astype(bool)
+ batch.truncated = batch.truncated.astype(bool)
+ if self._meta.is_empty():
+ self._meta = _create_value( # type: ignore
+ batch, self.maxsize, stack=False)
+ else: # dynamic key pops up in batch
+ _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
+ self._set_batch_for_children()
+ self._meta[ptrs] = batch
+ return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
+
+ def sample_indices(self, batch_size: int) -> np.ndarray:
+ if batch_size < 0:
+ return np.array([], int)
+ if self._sample_avail and self.stack_num > 1:
+ all_indices = np.concatenate(
+ [
+ buf.sample_indices(0) + offset
+ for offset, buf in zip(self._offset, self.buffers)
+ ]
+ )
+ if batch_size == 0:
+ return all_indices
+ else:
+ return np.random.choice(all_indices, batch_size)
+ if batch_size == 0: # get all available indices
+ sample_num = np.zeros(self.buffer_num, int)
+ else:
+ buffer_idx = np.random.choice(
+ self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
+ )
+ sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
+ # avoid batch_size > 0 and sample_num == 0 -> get child's all data
+ sample_num[sample_num == 0] = -1
+
+ return np.concatenate(
+ [
+ buf.sample_indices(bsz) + offset
+ for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
+ ]
+ )
+
+
+class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager):
+ """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with \
+ exactly the same configuration.
+
+ These replay buffers have contiguous memory layout, and the storage space each
+ buffer has is a shallow copy of the topmost memory.
+
+ :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None:
+ ReplayBufferManager.__init__(self, buffer_list) # type: ignore
+ kwargs = buffer_list[0].options
+ for buf in buffer_list:
+ del buf.weight
+ PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs)
+
+
+class HERReplayBufferManager(ReplayBufferManager):
+ """HERReplayBufferManager contains a list of HERReplayBuffer with \
+ exactly the same configuration.
+
+ These replay buffers have contiguous memory layout, and the storage space each
+ buffer has is a shallow copy of the topmost memory.
+
+ :param buffer_list: a list of HERReplayBuffer needed to be handled.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(self, buffer_list: List[HERReplayBuffer]) -> None:
+ super().__init__(buffer_list)
+
+ def _restore_cache(self) -> None:
+ for buf in self.buffers:
+ buf._restore_cache()
+
+ def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
+ self._restore_cache()
+ return super().save_hdf5(path, compression)
+
+ def set_batch(self, batch: Batch) -> None:
+ self._restore_cache()
+ return super().set_batch(batch)
+
+ def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray:
+ self._restore_cache()
+ return super().update(buffer)
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ self._restore_cache()
+ return super().add(batch, buffer_ids)
+
+
+@njit
+def _prev_index(
+ index: np.ndarray,
+ offset: np.ndarray,
+ done: np.ndarray,
+ last_index: np.ndarray,
+ lengths: np.ndarray,
+) -> np.ndarray:
+ index = index % offset[-1]
+ prev_index = np.zeros_like(index)
+ for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+ mask = (start <= index) & (index < end)
+ cur_len = max(1, cur_len)
+ if np.sum(mask) > 0:
+ subind = index[mask]
+ subind = (subind - start - 1) % cur_len
+ end_flag = done[subind + start] | (subind + start == last)
+ prev_index[mask] = (subind + end_flag) % cur_len + start
+ return prev_index
+
+
+@njit
+def _next_index(
+ index: np.ndarray,
+ offset: np.ndarray,
+ done: np.ndarray,
+ last_index: np.ndarray,
+ lengths: np.ndarray,
+) -> np.ndarray:
+ index = index % offset[-1]
+ next_index = np.zeros_like(index)
+ for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+ mask = (start <= index) & (index < end)
+ cur_len = max(1, cur_len)
+ if np.sum(mask) > 0:
+ subind = index[mask]
+ end_flag = done[subind] | (subind == last)
+ next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
+ return next_index
diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py
new file mode 100644
index 0000000..dc30a5d
--- /dev/null
+++ b/tianshou/data/buffer/prio.py
@@ -0,0 +1,105 @@
+from typing import Any, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+ """Implementation of Prioritized Experience Replay. arXiv:1511.05952.
+
+ :param float alpha: the prioritization exponent.
+ :param float beta: the importance sample soft coefficient.
+ :param bool weight_norm: whether to normalize returned weights with the maximum
+ weight value within the batch. Default to True.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ alpha: float,
+ beta: float,
+ weight_norm: bool = True,
+ **kwargs: Any
+ ) -> None:
+ # will raise KeyError in PrioritizedVectorReplayBuffer
+ # super().__init__(size, **kwargs)
+ ReplayBuffer.__init__(self, size, **kwargs)
+ assert alpha > 0.0 and beta >= 0.0
+ self._alpha, self._beta = alpha, beta
+ self._max_prio = self._min_prio = 1.0
+ # save weight directly in this class instead of self._meta
+ self.weight = SegmentTree(size)
+ self.__eps = np.finfo(np.float32).eps.item()
+ self.options.update(alpha=alpha, beta=beta)
+ self._weight_norm = weight_norm
+
+ def init_weight(self, index: Union[int, np.ndarray]) -> None:
+ self.weight[index] = self._max_prio**self._alpha
+
+ def update(self, buffer: ReplayBuffer) -> np.ndarray:
+ indices = super().update(buffer)
+ self.init_weight(indices)
+ return indices
+
+ def add(
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
+ self.init_weight(ptr)
+ return ptr, ep_rew, ep_len, ep_idx
+
+ def sample_indices(self, batch_size: int) -> np.ndarray:
+ if batch_size > 0 and len(self) > 0:
+ scalar = np.random.rand(batch_size) * self.weight.reduce()
+ return self.weight.get_prefix_sum_idx(scalar) # type: ignore
+ else:
+ return super().sample_indices(batch_size)
+
+ def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
+ """Get the importance sampling weight.
+
+ The "weight" in the returned Batch is the weight on loss function to debias
+ the sampling process (some transition tuples are sampled more often so their
+ losses are weighted less).
+ """
+ # important sampling weight calculation
+ # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
+ # simplified formula: (p_j/p_min)**(-beta)
+ return (self.weight[index] / self._min_prio)**(-self._beta)
+
+ def update_weight(
+ self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
+ ) -> None:
+ """Update priority weight by index in this buffer.
+
+ :param np.ndarray index: index you want to update weight.
+ :param np.ndarray new_weight: new priority weight you want to update.
+ """
+ weight = np.abs(to_numpy(new_weight)) + self.__eps
+ self.weight[index] = weight**self._alpha
+ self._max_prio = max(self._max_prio, weight.max())
+ self._min_prio = min(self._min_prio, weight.min())
+
+ def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
+ if isinstance(index, slice): # change slice to np array
+ # buffer[:] will get all available data
+ indices = self.sample_indices(0) if index == slice(None) \
+ else self._indices[:len(self)][index]
+ else:
+ indices = index # type: ignore
+ batch = super().__getitem__(indices)
+ weight = self.get_weight(indices)
+ # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
+ batch.weight = weight / np.max(weight) if self._weight_norm else weight
+ return batch
+
+ def set_beta(self, beta: float) -> None:
+ self._beta = beta
diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py
new file mode 100644
index 0000000..a08ad98
--- /dev/null
+++ b/tianshou/data/buffer/vecbuf.py
@@ -0,0 +1,91 @@
+from typing import Any
+
+import numpy as np
+
+from tianshou.data import (
+ HERReplayBuffer,
+ HERReplayBufferManager,
+ PrioritizedReplayBuffer,
+ PrioritizedReplayBufferManager,
+ ReplayBuffer,
+ ReplayBufferManager,
+)
+
+
+class VectorReplayBuffer(ReplayBufferManager):
+ """VectorReplayBuffer contains n ReplayBuffer with the same size.
+
+ It is used for storing transition from different environments yet keeping the order
+ of time.
+
+ :param int total_size: the total size of VectorReplayBuffer.
+ :param int buffer_num: the number of ReplayBuffer it uses, which are under the same
+ configuration.
+
+ Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail)
+ are the same as :class:`~tianshou.data.ReplayBuffer`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
+ assert buffer_num > 0
+ size = int(np.ceil(total_size / buffer_num))
+ buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
+ super().__init__(buffer_list)
+
+
+class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
+ """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size.
+
+ It is used for storing transition from different environments yet keeping the order
+ of time.
+
+ :param int total_size: the total size of PrioritizedVectorReplayBuffer.
+ :param int buffer_num: the number of PrioritizedReplayBuffer it uses, which are
+ under the same configuration.
+
+ Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/
+ sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
+ assert buffer_num > 0
+ size = int(np.ceil(total_size / buffer_num))
+ buffer_list = [
+ PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)
+ ]
+ super().__init__(buffer_list)
+
+ def set_beta(self, beta: float) -> None:
+ for buffer in self.buffers:
+ buffer.set_beta(beta)
+
+
+class HERVectorReplayBuffer(HERReplayBufferManager):
+ """HERVectorReplayBuffer contains n HERReplayBuffer with same size.
+
+ It is used for storing transition from different environments yet keeping the order
+ of time.
+
+ :param int total_size: the total size of HERVectorReplayBuffer.
+ :param int buffer_num: the number of HERReplayBuffer it uses, which are
+ under the same configuration.
+
+ Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`.
+
+ .. seealso::
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
+ """
+
+ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
+ assert buffer_num > 0
+ size = int(np.ceil(total_size / buffer_num))
+ buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
+ super().__init__(buffer_list)
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
new file mode 100644
index 0000000..3b41b8e
--- /dev/null
+++ b/tianshou/data/collector.py
@@ -0,0 +1,672 @@
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import gymnasium as gym
+import numpy as np
+import torch
+
+from tianshou.data import (
+ Batch,
+ CachedReplayBuffer,
+ ReplayBuffer,
+ ReplayBufferManager,
+ VectorReplayBuffer,
+ to_numpy,
+)
+from tianshou.data.batch import _alloc_by_keys_diff
+from tianshou.env import BaseVectorEnv, DummyVectorEnv
+from tianshou.policy import BasePolicy
+
+
+class Collector(object):
+ """Collector enables the policy to interact with different types of envs with \
+ exact number of steps or episodes.
+
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param env: a ``gym.Env`` environment or an instance of the
+ :class:`~tianshou.env.BaseVectorEnv` class.
+ :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+ If set to None, it will not store the data. Default to None.
+ :param function preprocess_fn: a function called before the data has been added to
+ the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
+ :param bool exploration_noise: determine whether the action needs to be modified
+ with corresponding policy's exploration noise. If so, "policy.
+ exploration_noise(act, batch)" will be called automatically to add the
+ exploration noise into action. Default to False.
+
+ The "preprocess_fn" is a function called before the data has been added to the
+ buffer with batch format. It will receive only "obs" and "env_id" when the
+ collector resets the environment, and will receive the keys "obs_next", "rew",
+ "terminated", "truncated, "info", "policy" and "env_id" in a normal env step.
+ Alternatively, it may also accept the keys "obs_next", "rew", "done", "info",
+ "policy" and "env_id".
+ It returns either a dict or a :class:`~tianshou.data.Batch` with the modified
+ keys and values. Examples are in "test/base/test_collector.py".
+
+ .. note::
+
+ Please make sure the given environment has a time limitation if using n_episode
+ collect option.
+
+ .. note::
+
+ In past versions of Tianshou, the replay buffer that was passed to `__init__`
+ was automatically reset. This is not done in the current implementation.
+ """
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ env: Union[gym.Env, BaseVectorEnv],
+ buffer: Optional[ReplayBuffer] = None,
+ preprocess_fn: Optional[Callable[..., Batch]] = None,
+ exploration_noise: bool = False,
+ ) -> None:
+ super().__init__()
+ if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
+ warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
+ self.env = DummyVectorEnv([lambda: env]) # type: ignore
+ else:
+ self.env = env # type: ignore
+ self.env_num = len(self.env)
+ self.exploration_noise = exploration_noise
+ self._assign_buffer(buffer)
+ self.policy = policy
+ self.preprocess_fn = preprocess_fn
+ self._action_space = self.env.action_space
+ # avoid creating attribute outside __init__
+ self.reset(False)
+
+ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
+ """Check if the buffer matches the constraint."""
+ if buffer is None:
+ buffer = VectorReplayBuffer(self.env_num, self.env_num)
+ elif isinstance(buffer, ReplayBufferManager):
+ assert buffer.buffer_num >= self.env_num
+ if isinstance(buffer, CachedReplayBuffer):
+ assert buffer.cached_buffer_num >= self.env_num
+ else: # ReplayBuffer or PrioritizedReplayBuffer
+ assert buffer.maxsize > 0
+ if self.env_num > 1:
+ if type(buffer) == ReplayBuffer:
+ buffer_type = "ReplayBuffer"
+ vector_type = "VectorReplayBuffer"
+ else:
+ buffer_type = "PrioritizedReplayBuffer"
+ vector_type = "PrioritizedVectorReplayBuffer"
+ raise TypeError(
+ f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
+ f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
+ f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead."
+ )
+ self.buffer = buffer
+
+ def reset(
+ self,
+ reset_buffer: bool = True,
+ gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """Reset the environment, statistics, current data and possibly replay memory.
+
+ :param bool reset_buffer: if true, reset the replay buffer that is attached
+ to the collector.
+ :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+ reset function. Defaults to None (extra keyword arguments)
+ """
+ # use empty Batch for "state" so that self.data supports slicing
+ # convert empty Batch to None when passing data to policy
+ self.data = Batch(
+ obs={},
+ act={},
+ rew={},
+ terminated={},
+ truncated={},
+ done={},
+ obs_next={},
+ info={},
+ policy={}
+ )
+ self.reset_env(gym_reset_kwargs)
+ if reset_buffer:
+ self.reset_buffer()
+ self.reset_stat()
+
+ def reset_stat(self) -> None:
+ """Reset the statistic variables."""
+ self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
+
+ def reset_buffer(self, keep_statistics: bool = False) -> None:
+ """Reset the data buffer."""
+ self.buffer.reset(keep_statistics=keep_statistics)
+
+ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+ """Reset all of the environments."""
+ gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+ obs, info = self.env.reset(**gym_reset_kwargs)
+ if self.preprocess_fn:
+ processed_data = self.preprocess_fn(
+ obs=obs, info=info, env_id=np.arange(self.env_num)
+ )
+ obs = processed_data.get("obs", obs)
+ info = processed_data.get("info", info)
+ self.data.info = info
+ self.data.obs = obs
+
+ def _reset_state(self, id: Union[int, List[int]]) -> None:
+ """Reset the hidden state: self.data.state[id]."""
+ if hasattr(self.data.policy, "hidden_state"):
+ state = self.data.policy.hidden_state # it is a reference
+ if isinstance(state, torch.Tensor):
+ state[id].zero_()
+ elif isinstance(state, np.ndarray):
+ state[id] = None if state.dtype == object else 0
+ elif isinstance(state, Batch):
+ state.empty_(id)
+
+ def _reset_env_with_ids(
+ self,
+ local_ids: Union[List[int], np.ndarray],
+ global_ids: Union[List[int], np.ndarray],
+ gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+ obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
+ if self.preprocess_fn:
+ processed_data = self.preprocess_fn(
+ obs=obs_reset, info=info, env_id=global_ids
+ )
+ obs_reset = processed_data.get("obs", obs_reset)
+ info = processed_data.get("info", info)
+ self.data.info[local_ids] = info
+
+ self.data.obs_next[local_ids] = obs_reset
+
+ def collect(
+ self,
+ n_step: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ random: bool = False,
+ render: Optional[float] = None,
+ no_grad: bool = True,
+ gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ """Collect a specified number of step or episode.
+
+ To ensure unbiased sampling result with n_episode option, this function will
+ first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
+ episodes, they will be collected evenly from each env.
+
+ :param int n_step: how many steps you want to collect.
+ :param int n_episode: how many episodes you want to collect.
+ :param bool random: whether to use random policy for collecting data. Default
+ to False. 需要考虑在训练过程中是否使用随即策略来收集数据 (答案是不需要)
+ :param float render: the sleep time between rendering consecutive frames.
+ Default to None (no rendering).
+ :param bool no_grad: whether to retain gradient in policy.forward(). Default to
+ True (no gradient retaining).
+ :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+ reset function. Defaults to None (extra keyword arguments)
+
+ .. note::
+
+ One and only one collection number specification is permitted, either
+ ``n_step`` or ``n_episode``.
+
+ :return: A dict including the following keys
+
+ * ``n/ep`` collected number of episodes.
+ * ``n/st`` collected number of steps.
+ * ``rews`` array of episode reward over collected episodes.
+ * ``lens`` array of episode length over collected episodes.
+ * ``idxs`` array of episode start index in buffer over collected episodes.
+ * ``rew`` mean of episodic rewards.
+ * ``len`` mean of episodic lengths.
+ * ``rew_std`` standard error of episodic rewards.
+ * ``len_std`` standard error of episodic lengths.
+ """
+ assert not self.env.is_async, "Please use AsyncCollector if using async venv."
+ if n_step is not None:
+ assert n_episode is None, (
+ f"Only one of n_step or n_episode is allowed in Collector."
+ f"collect, got n_step={n_step}, n_episode={n_episode}."
+ )
+ assert n_step > 0
+ if not n_step % self.env_num == 0:
+ warnings.warn(
+ f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
+ "which may cause extra transitions collected into the buffer."
+ )
+ ready_env_ids = np.arange(self.env_num)
+ elif n_episode is not None:
+ assert n_episode > 0
+ ready_env_ids = np.arange(min(self.env_num, n_episode))
+ self.data = self.data[:min(self.env_num, n_episode)]
+ else:
+ raise TypeError(
+ "Please specify at least one (either n_step or n_episode) "
+ "in AsyncCollector.collect()."
+ )
+
+ start_time = time.time()
+
+ step_count = 0
+ episode_count = 0
+ episode_rews = []
+ episode_lens = []
+ episode_start_indices = []
+
+ while True:
+ assert len(self.data) == len(ready_env_ids)
+ # restore the state: if the last state is None, it won't store
+ last_state = self.data.policy.pop("hidden_state", None)
+
+ # get the next action
+ if random:
+ try:
+ act_sample = [
+ self._action_space[i].sample() for i in ready_env_ids
+ ]
+ except TypeError: # envpool's action space is not for per-env
+ act_sample = [self._action_space.sample() for _ in ready_env_ids]
+ act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
+ self.data.update(act=act_sample)
+ else:
+ if no_grad:
+ with torch.no_grad(): # faster than retain_grad version
+ # self.data.obs will be used by agent to get result
+ result = self.policy(self.data, last_state)
+ else:
+ result = self.policy(self.data, last_state)
+ # update state / act / policy into self.data
+ policy = result.get("policy", Batch())
+ assert isinstance(policy, Batch)
+ state = result.get("state", None)
+ if state is not None:
+ policy.hidden_state = state # save state into buffer
+ act = to_numpy(result.act)
+ if self.exploration_noise:
+ act = self.policy.exploration_noise(act, self.data)
+ self.data.update(policy=policy, act=act)
+
+ # get bounded and remapped actions first (not saved into buffer)
+ action_remap = self.policy.map_action(self.data.act)
+ # step in env
+ obs_next, rew, terminated, truncated, info = self.env.step(
+ action_remap, # type: ignore
+ ready_env_ids
+ )
+ done = np.logical_or(terminated, truncated)
+
+ self.data.update(
+ obs_next=obs_next,
+ rew=rew,
+ terminated=terminated,
+ truncated=truncated,
+ done=done,
+ info=info
+ )
+ if self.preprocess_fn:
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ done=self.data.done,
+ info=self.data.info,
+ policy=self.data.policy,
+ env_id=ready_env_ids,
+ act=self.data.act,
+ )
+ )
+
+ if render:
+ self.env.render()
+ if render > 0 and not np.isclose(render, 0):
+ time.sleep(render)
+
+ # add data into the buffer
+ ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
+ self.data, buffer_ids=ready_env_ids
+ )
+
+ # collect statistics
+ step_count += len(ready_env_ids)
+
+ if np.any(done):
+ env_ind_local = np.where(done)[0] # done为True的env的索引
+ env_ind_global = ready_env_ids[env_ind_local]
+ episode_count += len(env_ind_local) # 记录完成的episode数
+ episode_lens.append(ep_len[env_ind_local])
+ episode_rews.append(ep_rew[env_ind_local])
+ episode_start_indices.append(ep_idx[env_ind_local])
+ # now we copy obs_next to obs, but since there might be
+ # finished episodes, we have to reset finished envs first.
+ self._reset_env_with_ids(
+ env_ind_local, env_ind_global, gym_reset_kwargs
+ )
+ for i in env_ind_local:
+ self._reset_state(i)
+
+ # remove surplus env id from ready_env_ids
+ # to avoid bias in selecting environments
+ if n_episode:
+ surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
+ if surplus_env_num > 0:
+ mask = np.ones_like(ready_env_ids, dtype=bool)
+ mask[env_ind_local[:surplus_env_num]] = False
+ ready_env_ids = ready_env_ids[mask]
+ self.data = self.data[mask]
+
+ self.data.obs = self.data.obs_next
+
+ if (n_step and step_count >= n_step) or \
+ (n_episode and episode_count >= n_episode):
+ break
+
+ # generate statistics
+ self.collect_step += step_count
+ self.collect_episode += episode_count
+ self.collect_time += max(time.time() - start_time, 1e-9)
+
+ if n_episode:
+ self.data = Batch(
+ obs={},
+ act={},
+ rew={},
+ terminated={},
+ truncated={},
+ done={},
+ obs_next={},
+ info={},
+ policy={}
+ )
+ self.reset_env()
+
+ if episode_count > 0:
+ rews, lens, idxs = list(
+ map(
+ np.concatenate,
+ [episode_rews, episode_lens, episode_start_indices]
+ )
+ )
+ rew_mean, rew_std = rews.mean(), rews.std()
+ len_mean, len_std = lens.mean(), lens.std()
+ else:
+ rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+ rew_mean = rew_std = len_mean = len_std = 0
+
+ return {
+ "n/ep": episode_count,
+ "n/st": step_count,
+ "rews": rews,
+ "lens": lens,
+ "idxs": idxs,
+ "rew": rew_mean,
+ "len": len_mean,
+ "rew_std": rew_std,
+ "len_std": len_std,
+ }
+
+
+class AsyncCollector(Collector):
+ """Async Collector handles async vector environment.
+
+ The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
+ refer to :class:`~tianshou.data.Collector` for more detailed explanation.
+ """
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ env: BaseVectorEnv,
+ buffer: Optional[ReplayBuffer] = None,
+ preprocess_fn: Optional[Callable[..., Batch]] = None,
+ exploration_noise: bool = False,
+ ) -> None:
+ # assert env.is_async
+ warnings.warn("Using async setting may collect extra transitions into buffer.")
+ super().__init__(
+ policy,
+ env,
+ buffer,
+ preprocess_fn,
+ exploration_noise,
+ )
+
+ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+ super().reset_env(gym_reset_kwargs)
+ self._ready_env_ids = np.arange(self.env_num)
+
+ def collect(
+ self,
+ n_step: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ random: bool = False,
+ render: Optional[float] = None,
+ no_grad: bool = True,
+ gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ """Collect a specified number of step or episode with async env setting.
+
+ This function doesn't collect exactly n_step or n_episode number of
+ transitions. Instead, in order to support async setting, it may collect more
+ than given n_step or n_episode transitions and save into buffer.
+
+ :param int n_step: how many steps you want to collect.
+ :param int n_episode: how many episodes you want to collect.
+ :param bool random: whether to use random policy for collecting data. Default
+ to False.
+ :param float render: the sleep time between rendering consecutive frames.
+ Default to None (no rendering).
+ :param bool no_grad: whether to retain gradient in policy.forward(). Default to
+ True (no gradient retaining).
+ :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+ reset function. Defaults to None (extra keyword arguments)
+
+ .. note::
+
+ One and only one collection number specification is permitted, either
+ ``n_step`` or ``n_episode``.
+
+ :return: A dict including the following keys
+
+ * ``n/ep`` collected number of episodes.
+ * ``n/st`` collected number of steps.
+ * ``rews`` array of episode reward over collected episodes.
+ * ``lens`` array of episode length over collected episodes.
+ * ``idxs`` array of episode start index in buffer over collected episodes.
+ * ``rew`` mean of episodic rewards.
+ * ``len`` mean of episodic lengths.
+ * ``rew_std`` standard error of episodic rewards.
+ * ``len_std`` standard error of episodic lengths.
+ """
+ # collect at least n_step or n_episode
+ if n_step is not None:
+ assert n_episode is None, (
+ "Only one of n_step or n_episode is allowed in Collector."
+ f"collect, got n_step={n_step}, n_episode={n_episode}."
+ )
+ assert n_step > 0
+ elif n_episode is not None:
+ assert n_episode > 0
+ else:
+ raise TypeError(
+ "Please specify at least one (either n_step or n_episode) "
+ "in AsyncCollector.collect()."
+ )
+
+ ready_env_ids = self._ready_env_ids
+
+ start_time = time.time()
+
+ step_count = 0
+ episode_count = 0
+ episode_rews = []
+ episode_lens = []
+ episode_start_indices = []
+
+ while True:
+ whole_data = self.data
+ self.data = self.data[ready_env_ids]
+ assert len(whole_data) == self.env_num # major difference
+ # restore the state: if the last state is None, it won't store
+ last_state = self.data.policy.pop("hidden_state", None)
+
+ # get the next action
+ if random:
+ try:
+ act_sample = [
+ self._action_space[i].sample() for i in ready_env_ids
+ ]
+ except TypeError: # envpool's action space is not for per-env
+ act_sample = [self._action_space.sample() for _ in ready_env_ids]
+ act_sample = self.policy.map_action_inverse(act_sample) # type: ignore
+ self.data.update(act=act_sample)
+ else:
+ if no_grad:
+ with torch.no_grad(): # faster than retain_grad version
+ # self.data.obs will be used by agent to get result
+ result = self.policy(self.data, last_state)
+ else:
+ result = self.policy(self.data, last_state)
+ # update state / act / policy into self.data
+ policy = result.get("policy", Batch())
+ assert isinstance(policy, Batch)
+ state = result.get("state", None)
+ if state is not None:
+ policy.hidden_state = state # save state into buffer
+ act = to_numpy(result.act)
+ if self.exploration_noise:
+ act = self.policy.exploration_noise(act, self.data)
+ self.data.update(policy=policy, act=act)
+
+ # save act/policy before env.step
+ try:
+ whole_data.act[ready_env_ids] = self.data.act
+ whole_data.policy[ready_env_ids] = self.data.policy
+ except ValueError:
+ _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
+ whole_data[ready_env_ids] = self.data # lots of overhead
+
+ # get bounded and remapped actions first (not saved into buffer)
+ action_remap = self.policy.map_action(self.data.act)
+ # step in env
+ obs_next, rew, terminated, truncated, info = self.env.step(
+ action_remap, # type: ignore
+ ready_env_ids
+ )
+ done = np.logical_or(terminated, truncated)
+
+ # change self.data here because ready_env_ids has changed
+ try:
+ ready_env_ids = info["env_id"]
+ except Exception:
+ ready_env_ids = np.array([i["env_id"] for i in info])
+ self.data = whole_data[ready_env_ids]
+
+ self.data.update(
+ obs_next=obs_next,
+ rew=rew,
+ terminated=terminated,
+ truncated=truncated,
+ info=info
+ )
+ if self.preprocess_fn:
+ try:
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ terminated=self.data.terminated,
+ truncated=self.data.truncated,
+ info=self.data.info,
+ env_id=ready_env_ids,
+ act=self.data.act,
+ )
+ )
+ except TypeError:
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ done=self.data.done,
+ info=self.data.info,
+ env_id=ready_env_ids,
+ act=self.data.act,
+ )
+ )
+
+ if render:
+ self.env.render()
+ if render > 0 and not np.isclose(render, 0):
+ time.sleep(render)
+
+ # add data into the buffer
+ ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
+ self.data, buffer_ids=ready_env_ids
+ )
+
+ # collect statistics
+ step_count += len(ready_env_ids)
+
+ if np.any(done):
+ env_ind_local = np.where(done)[0]
+ env_ind_global = ready_env_ids[env_ind_local]
+ episode_count += len(env_ind_local)
+ episode_lens.append(ep_len[env_ind_local])
+ episode_rews.append(ep_rew[env_ind_local])
+ episode_start_indices.append(ep_idx[env_ind_local])
+ # now we copy obs_next to obs, but since there might be
+ # finished episodes, we have to reset finished envs first.
+ self._reset_env_with_ids(
+ env_ind_local, env_ind_global, gym_reset_kwargs
+ )
+ for i in env_ind_local:
+ self._reset_state(i)
+
+ try:
+ whole_data.obs[ready_env_ids] = self.data.obs_next
+ whole_data.rew[ready_env_ids] = self.data.rew
+ whole_data.done[ready_env_ids] = self.data.done
+ whole_data.info[ready_env_ids] = self.data.info
+ except ValueError:
+ _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
+ self.data.obs = self.data.obs_next
+ whole_data[ready_env_ids] = self.data # lots of overhead
+ self.data = whole_data
+
+ if (n_step and step_count >= n_step) or \
+ (n_episode and episode_count >= n_episode):
+ break
+
+ self._ready_env_ids = ready_env_ids
+
+ # generate statistics
+ self.collect_step += step_count
+ self.collect_episode += episode_count
+ self.collect_time += max(time.time() - start_time, 1e-9)
+
+ if episode_count > 0:
+ rews, lens, idxs = list(
+ map(
+ np.concatenate,
+ [episode_rews, episode_lens, episode_start_indices]
+ )
+ )
+ rew_mean, rew_std = rews.mean(), rews.std()
+ len_mean, len_std = lens.mean(), lens.std()
+ else:
+ rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+ rew_mean = rew_std = len_mean = len_std = 0
+
+ return {
+ "n/ep": episode_count,
+ "n/st": step_count,
+ "rews": rews,
+ "lens": lens,
+ "idxs": idxs,
+ "rew": rew_mean,
+ "len": len_mean,
+ "rew_std": rew_std,
+ "len_std": len_std,
+ }
diff --git a/tianshou/data/utils/__init__.py b/tianshou/data/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
new file mode 100644
index 0000000..68e7508
--- /dev/null
+++ b/tianshou/data/utils/converter.py
@@ -0,0 +1,155 @@
+import pickle
+from copy import deepcopy
+from numbers import Number
+from typing import Any, Dict, Optional, Union, no_type_check
+
+import h5py
+import numpy as np
+import torch
+
+from tianshou.data.batch import Batch, _parse_value
+
+
+@no_type_check
+def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
+ """Return an object without torch.Tensor."""
+ if isinstance(x, torch.Tensor): # most often case
+ return x.detach().cpu().numpy()
+ elif isinstance(x, np.ndarray): # second often case
+ return x
+ elif isinstance(x, (np.number, np.bool_, Number)):
+ return np.asanyarray(x)
+ elif x is None:
+ return np.array(None, dtype=object)
+ elif isinstance(x, (dict, Batch)):
+ x = Batch(x) if isinstance(x, dict) else deepcopy(x)
+ x.to_numpy()
+ return x
+ elif isinstance(x, (list, tuple)):
+ return to_numpy(_parse_value(x))
+ else: # fallback
+ return np.asanyarray(x)
+
+
+@no_type_check
+def to_torch(
+ x: Any,
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, int, torch.device] = "cpu",
+) -> Union[Batch, torch.Tensor]:
+ """Return an object without np.ndarray."""
+ if isinstance(x, np.ndarray) and issubclass(
+ x.dtype.type, (np.bool_, np.number)
+ ): # most often case
+ x = torch.from_numpy(x).to(device)
+ if dtype is not None:
+ x = x.type(dtype)
+ return x
+ elif isinstance(x, torch.Tensor): # second often case
+ if dtype is not None:
+ x = x.type(dtype)
+ return x.to(device)
+ elif isinstance(x, (np.number, np.bool_, Number)):
+ return to_torch(np.asanyarray(x), dtype, device)
+ elif isinstance(x, (dict, Batch)):
+ x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
+ x.to_torch(dtype, device)
+ return x
+ elif isinstance(x, (list, tuple)):
+ return to_torch(_parse_value(x), dtype, device)
+ else: # fallback
+ raise TypeError(f"object {x} cannot be converted to torch.")
+
+
+@no_type_check
+def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
+ """Return an object without np.ndarray.
+
+ Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
+ """
+ assert isinstance(y, torch.Tensor)
+ return to_torch(x, dtype=y.dtype, device=y.device)
+
+
+# Note: object is used as a proxy for objects that can be pickled
+# Note: mypy does not support cyclic definition currently
+Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
+ "Hdf5ConvertibleType"]
+
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]
+
+
+def to_hdf5(
+ x: Hdf5ConvertibleType, y: h5py.Group, compression: Optional[str] = None
+) -> None:
+ """Copy object into HDF5 group."""
+
+ def to_hdf5_via_pickle(
+ x: object, y: h5py.Group, key: str, compression: Optional[str] = None
+ ) -> None:
+ """Pickle, convert to numpy array and write to HDF5 dataset."""
+ data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
+ y.create_dataset(key, data=data, compression=compression)
+
+ for k, v in x.items():
+ if isinstance(v, (Batch, dict)):
+ # dicts and batches are both represented by groups
+ subgrp = y.create_group(k)
+ if isinstance(v, Batch):
+ subgrp_data = v.__getstate__()
+ subgrp.attrs["__data_type__"] = "Batch"
+ else:
+ subgrp_data = v
+ to_hdf5(subgrp_data, subgrp, compression=compression)
+ elif isinstance(v, torch.Tensor):
+ # PyTorch tensors are written to datasets
+ y.create_dataset(k, data=to_numpy(v), compression=compression)
+ y[k].attrs["__data_type__"] = "Tensor"
+ elif isinstance(v, np.ndarray):
+ try:
+ # NumPy arrays are written to datasets
+ y.create_dataset(k, data=v, compression=compression)
+ y[k].attrs["__data_type__"] = "ndarray"
+ except TypeError:
+ # If data type is not supported by HDF5 fall back to pickle.
+ # This happens if dtype=object (e.g. due to entries being None)
+ # and possibly in other cases like structured arrays.
+ try:
+ to_hdf5_via_pickle(v, y, k, compression=compression)
+ except Exception as exception:
+ raise RuntimeError(
+ f"Attempted to pickle {v.__class__.__name__} due to "
+ "data type not supported by HDF5 and failed."
+ ) from exception
+ y[k].attrs["__data_type__"] = "pickled_ndarray"
+ elif isinstance(v, (int, float)):
+ # ints and floats are stored as attributes of groups
+ y.attrs[k] = v
+ else: # resort to pickle for any other type of object
+ try:
+ to_hdf5_via_pickle(v, y, k, compression=compression)
+ except Exception as exception:
+ raise NotImplementedError(
+ f"No conversion to HDF5 for object of type '{type(v)}' "
+ "implemented and fallback to pickle failed."
+ ) from exception
+ y[k].attrs["__data_type__"] = v.__class__.__name__
+
+
+def from_hdf5(x: h5py.Group, device: Optional[str] = None) -> Hdf5ConvertibleValues:
+ """Restore object from HDF5 group."""
+ if isinstance(x, h5py.Dataset):
+ # handle datasets
+ if x.attrs["__data_type__"] == "ndarray":
+ return np.array(x)
+ elif x.attrs["__data_type__"] == "Tensor":
+ return torch.tensor(x, device=device)
+ else:
+ return pickle.loads(x[()])
+ else:
+ # handle groups representing a dict or a Batch
+ y = dict(x.attrs.items())
+ data_type = y.pop("__data_type__", None)
+ for k, v in x.items():
+ y[k] = from_hdf5(v, device)
+ return Batch(y) if data_type == "Batch" else y
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
new file mode 100644
index 0000000..063675c
--- /dev/null
+++ b/tianshou/data/utils/segtree.py
@@ -0,0 +1,137 @@
+from typing import Optional, Union
+
+import numpy as np
+from numba import njit
+
+
+class SegmentTree:
+ """Implementation of Segment Tree.
+
+ The segment tree stores an array ``arr`` with size ``n``. It supports value
+ update and fast query of the sum for the interval ``[left, right)`` in
+ O(log n) time. The detailed procedure is as follows:
+
+ 1. Pad the array to have length of power of 2, so that leaf nodes in the \
+ segment tree have the same depth.
+ 2. Store the segment tree in a binary heap.
+
+ :param int size: the size of segment tree.
+ """
+
+ def __init__(self, size: int) -> None:
+ bound = 1
+ while bound < size:
+ bound *= 2
+ self._size = size
+ self._bound = bound
+ self._value = np.zeros([bound * 2])
+ self._compile()
+
+ def __len__(self) -> int:
+ return self._size
+
+ def __getitem__(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
+ """Return self[index]."""
+ return self._value[index + self._bound]
+
+ def __setitem__(
+ self, index: Union[int, np.ndarray], value: Union[float, np.ndarray]
+ ) -> None:
+ """Update values in segment tree.
+
+ Duplicate values in ``index`` are handled by numpy: later index
+ overwrites previous ones.
+ ::
+
+ >>> a = np.array([1, 2, 3, 4])
+ >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
+ >>> print(a)
+ [6 7 3 4]
+ """
+ if isinstance(index, int):
+ index, value = np.array([index]), np.array([value])
+ assert np.all(0 <= index) and np.all(index < self._size)
+ _setitem(self._value, index + self._bound, value)
+
+ def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
+ """Return operation(value[start:end])."""
+ if start == 0 and end is None:
+ return self._value[1]
+ if end is None:
+ end = self._size
+ if end < 0:
+ end += self._size
+ return _reduce(self._value, start + self._bound - 1, end + self._bound)
+
+ def get_prefix_sum_idx(self, value: Union[float,
+ np.ndarray]) -> Union[int, np.ndarray]:
+ r"""Find the index with given value.
+
+ Return the minimum index for each ``v`` in ``value`` so that
+ :math:`v \le \mathrm{sums}_i`, where
+ :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`.
+
+ .. warning::
+
+ Please make sure all of the values inside the segment tree are
+ non-negative when using this function.
+ """
+ assert np.all(value >= 0.0) and np.all(value < self._value[1])
+ single = False
+ if not isinstance(value, np.ndarray):
+ value = np.array([value])
+ single = True
+ index = _get_prefix_sum_idx(value, self._bound, self._value)
+ return index.item() if single else index
+
+ def _compile(self) -> None:
+ f64 = np.array([0, 1], dtype=np.float64)
+ f32 = np.array([0, 1], dtype=np.float32)
+ i64 = np.array([0, 1], dtype=np.int64)
+ _setitem(f64, i64, f64)
+ _setitem(f64, i64, f32)
+ _reduce(f64, 0, 1)
+ _get_prefix_sum_idx(f64, 1, f64)
+ _get_prefix_sum_idx(f32, 1, f64)
+
+
+@njit
+def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
+ """Numba version, 4x faster: 0.1 -> 0.024."""
+ tree[index] = value
+ while index[0] > 1:
+ index //= 2
+ tree[index] = tree[index * 2] + tree[index * 2 + 1]
+
+
+@njit
+def _reduce(tree: np.ndarray, start: int, end: int) -> float:
+ """Numba version, 2x faster: 0.009 -> 0.005."""
+ # nodes in (start, end) should be aggregated
+ result = 0.0
+ while end - start > 1: # (start, end) interval is not empty
+ if start % 2 == 0:
+ result += tree[start + 1]
+ start //= 2
+ if end % 2 == 1:
+ result += tree[end - 1]
+ end //= 2
+ return result
+
+
+@njit
+def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray:
+ """Numba version (v0.51), 5x speed up with size=100000 and bsz=64.
+
+ vectorized np: 0.0923 (numpy best) -> 0.024 (now)
+ for-loop: 0.2914 -> 0.019 (but not so stable)
+ """
+ index = np.ones(value.shape, dtype=np.int64)
+ while index[0] < bound:
+ index *= 2
+ lsons = sums[index]
+ direct = lsons < value
+ value -= lsons * direct
+ index += direct
+ index -= bound
+ return index
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py
new file mode 100644
index 0000000..a00c3cd
--- /dev/null
+++ b/tianshou/env/__init__.py
@@ -0,0 +1,34 @@
+"""Env package."""
+
+from tianshou.env.gym_wrappers import (
+ ContinuousToDiscrete,
+ MultiDiscreteToDiscrete,
+ TruncatedAsTerminated,
+)
+from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
+from tianshou.env.venvs import (
+ BaseVectorEnv,
+ DummyVectorEnv,
+ RayVectorEnv,
+ ShmemVectorEnv,
+ SubprocVectorEnv,
+)
+
+try:
+ from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+ pass
+
+__all__ = [
+ "BaseVectorEnv",
+ "DummyVectorEnv",
+ "SubprocVectorEnv",
+ "ShmemVectorEnv",
+ "RayVectorEnv",
+ "VectorEnvWrapper",
+ "VectorEnvNormObs",
+ "PettingZooEnv",
+ "ContinuousToDiscrete",
+ "MultiDiscreteToDiscrete",
+ "TruncatedAsTerminated",
+]
diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py
new file mode 100644
index 0000000..c9ce66a
--- /dev/null
+++ b/tianshou/env/gym_wrappers.py
@@ -0,0 +1,81 @@
+from typing import Any, Dict, List, SupportsFloat, Tuple, Union
+
+import gymnasium as gym
+import numpy as np
+from packaging import version
+
+
+class ContinuousToDiscrete(gym.ActionWrapper):
+ """Gym environment wrapper to take discrete action in a continuous environment.
+
+ :param gym.Env env: gym environment with continuous action space.
+ :param int action_per_dim: number of discrete actions in each dimension
+ of the action space.
+ """
+
+ def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None:
+ super().__init__(env)
+ assert isinstance(env.action_space, gym.spaces.Box)
+ low, high = env.action_space.low, env.action_space.high
+ if isinstance(action_per_dim, int):
+ action_per_dim = [action_per_dim] * env.action_space.shape[0]
+ assert len(action_per_dim) == env.action_space.shape[0]
+ self.action_space = gym.spaces.MultiDiscrete(action_per_dim)
+ self.mesh = np.array(
+ [np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim)],
+ dtype=object
+ )
+
+ def action(self, act: np.ndarray) -> np.ndarray: # type: ignore
+ # modify act
+ assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}."
+ if len(act.shape) == 1:
+ return np.array([self.mesh[i][a] for i, a in enumerate(act)])
+ return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act])
+
+
+class MultiDiscreteToDiscrete(gym.ActionWrapper):
+ """Gym environment wrapper to take discrete action in multidiscrete environment.
+
+ :param gym.Env env: gym environment with multidiscrete action space.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ super().__init__(env)
+ assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
+ nvec = env.action_space.nvec
+ assert nvec.ndim == 1
+ self.bases = np.ones_like(nvec)
+ for i in range(1, len(self.bases)):
+ self.bases[i] = self.bases[i - 1] * nvec[-i]
+ self.action_space = gym.spaces.Discrete(np.prod(nvec))
+
+ def action(self, act: np.ndarray) -> np.ndarray: # type: ignore
+ converted_act = []
+ for b in np.flip(self.bases):
+ converted_act.append(act // b)
+ act = act % b
+ return np.array(converted_act).transpose()
+
+
+class TruncatedAsTerminated(gym.Wrapper):
+ """A wrapper that set ``terminated = terminated or truncated`` for ``step()``.
+
+ It's intended to use with ``gym.wrappers.TimeLimit``.
+
+ :param gym.Env env: gym environment.
+ """
+
+ def __init__(self, env: gym.Env):
+ super().__init__(env)
+ if not version.parse(gym.__version__) >= version.parse('0.26.0'):
+ raise EnvironmentError(
+ f"TruncatedAsTerminated is not applicable with gym version \
+ {gym.__version__}"
+ )
+
+ def step(self,
+ act: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
+ observation, reward, terminated, truncated, info = super().step(act)
+ terminated = (terminated or truncated)
+ return observation, reward, terminated, truncated, info
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
new file mode 100644
index 0000000..9ab8280
--- /dev/null
+++ b/tianshou/env/pettingzoo_env.py
@@ -0,0 +1,131 @@
+import warnings
+from abc import ABC
+from typing import Any, Dict, List, Tuple
+
+import pettingzoo
+from gymnasium import spaces
+from packaging import version
+from pettingzoo.utils.env import AECEnv
+from pettingzoo.utils.wrappers import BaseWrapper
+
+if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
+ warnings.warn(
+ f"You are using PettingZoo {pettingzoo.__version__}. "
+ f"Future tianshou versions may not support PettingZoo<1.21.0. "
+ f"Consider upgrading your PettingZoo version.", DeprecationWarning
+ )
+
+
+class PettingZooEnv(AECEnv, ABC):
+ """The interface for petting zoo environments.
+
+ Multi-agent environments must be wrapped as
+ :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
+ ::
+
+ env = PettingZooEnv(...)
+ # obs is a dict containing obs, agent_id, and mask
+ obs = env.reset()
+ action = policy(obs)
+ obs, rew, trunc, term, info = env.step(action)
+ env.close()
+
+ The available action's mask is set to True, otherwise it is set to False.
+ Further usage can be found at :ref:`marl_example`.
+ """
+
+ def __init__(self, env: BaseWrapper):
+ super().__init__()
+ self.env = env
+ # agent idx list
+ self.agents = self.env.possible_agents
+ self.agent_idx = {}
+ for i, agent_id in enumerate(self.agents):
+ self.agent_idx[agent_id] = i
+
+ self.rewards = [0] * len(self.agents)
+
+ # Get first observation space, assuming all agents have equal space
+ self.observation_space: Any = self.env.observation_space(self.agents[0])
+
+ # Get first action space, assuming all agents have equal space
+ self.action_space: Any = self.env.action_space(self.agents[0])
+
+ assert all(self.env.observation_space(agent) == self.observation_space
+ for agent in self.agents), \
+ "Observation spaces for all agents must be identical. Perhaps " \
+ "SuperSuit's pad_observations wrapper can help (useage: " \
+ "`supersuit.aec_wrappers.pad_observations(env)`"
+
+ assert all(self.env.action_space(agent) == self.action_space
+ for agent in self.agents), \
+ "Action spaces for all agents must be identical. Perhaps " \
+ "SuperSuit's pad_action_space wrapper can help (useage: " \
+ "`supersuit.aec_wrappers.pad_action_space(env)`"
+
+ self.reset()
+
+ def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
+ self.env.reset(*args, **kwargs)
+
+ observation, reward, terminated, truncated, info = self.env.last(self)
+
+ if isinstance(observation, dict) and 'action_mask' in observation:
+ observation_dict = {
+ 'agent_id': self.env.agent_selection,
+ 'obs': observation['observation'],
+ 'mask':
+ [True if obm == 1 else False for obm in observation['action_mask']]
+ }
+ else:
+ if isinstance(self.action_space, spaces.Discrete):
+ observation_dict = {
+ 'agent_id': self.env.agent_selection,
+ 'obs': observation,
+ 'mask': [True] * self.env.action_space(self.env.agent_selection).n
+ }
+ else:
+ observation_dict = {
+ 'agent_id': self.env.agent_selection,
+ 'obs': observation,
+ }
+
+ return observation_dict, info
+
+ def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
+ self.env.step(action)
+
+ observation, rew, term, trunc, info = self.env.last()
+
+ if isinstance(observation, dict) and 'action_mask' in observation:
+ obs = {
+ 'agent_id': self.env.agent_selection,
+ 'obs': observation['observation'],
+ 'mask':
+ [True if obm == 1 else False for obm in observation['action_mask']]
+ }
+ else:
+ if isinstance(self.action_space, spaces.Discrete):
+ obs = {
+ 'agent_id': self.env.agent_selection,
+ 'obs': observation,
+ 'mask': [True] * self.env.action_space(self.env.agent_selection).n
+ }
+ else:
+ obs = {'agent_id': self.env.agent_selection, 'obs': observation}
+
+ for agent_id, reward in self.env.rewards.items():
+ self.rewards[self.agent_idx[agent_id]] = reward
+ return obs, self.rewards, term, trunc, info
+
+ def close(self) -> None:
+ self.env.close()
+
+ def seed(self, seed: Any = None) -> None:
+ try:
+ self.env.seed(seed)
+ except (NotImplementedError, AttributeError):
+ self.env.reset(seed=seed)
+
+ def render(self) -> Any:
+ return self.env.render()
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
new file mode 100644
index 0000000..cbd36d9
--- /dev/null
+++ b/tianshou/env/utils.py
@@ -0,0 +1,31 @@
+from typing import TYPE_CHECKING, Any, Tuple, Union
+
+import cloudpickle
+import gymnasium
+import numpy as np
+
+try:
+ from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+ PettingZooEnv = None # type: ignore
+
+if TYPE_CHECKING:
+ import gym
+
+ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
+
+gym_new_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray,
+ np.ndarray]
+
+
+class CloudpickleWrapper(object):
+ """A cloudpickle wrapper used in SubprocVectorEnv."""
+
+ def __init__(self, data: Any) -> None:
+ self.data = data
+
+ def __getstate__(self) -> str:
+ return cloudpickle.dumps(self.data)
+
+ def __setstate__(self, data: str) -> None:
+ self.data = cloudpickle.loads(data)
diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py
new file mode 100644
index 0000000..6647028
--- /dev/null
+++ b/tianshou/env/venv_wrappers.py
@@ -0,0 +1,123 @@
+from typing import Any, List, Optional, Tuple, Union
+
+import numpy as np
+
+from tianshou.env.utils import gym_new_venv_step_type
+from tianshou.env.venvs import GYM_RESERVED_KEYS, BaseVectorEnv
+from tianshou.utils import RunningMeanStd
+
+
+class VectorEnvWrapper(BaseVectorEnv):
+ """Base class for vectorized environments wrapper."""
+
+ def __init__(self, venv: BaseVectorEnv) -> None:
+ self.venv = venv
+ self.is_async = venv.is_async
+
+ def __len__(self) -> int:
+ return len(self.venv)
+
+ def __getattribute__(self, key: str) -> Any:
+ if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env
+ return getattr(self.venv, key)
+ else:
+ return super().__getattribute__(key)
+
+ def get_env_attr(
+ self,
+ key: str,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> List[Any]:
+ return self.venv.get_env_attr(key, id)
+
+ def set_env_attr(
+ self,
+ key: str,
+ value: Any,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> None:
+ return self.venv.set_env_attr(key, value, id)
+
+ def reset(
+ self,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
+ return self.venv.reset(id, **kwargs)
+
+ def step(
+ self,
+ action: np.ndarray,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> gym_new_venv_step_type:
+ return self.venv.step(action, id)
+
+ def seed(
+ self,
+ seed: Optional[Union[int, List[int]]] = None,
+ ) -> List[Optional[List[int]]]:
+ return self.venv.seed(seed)
+
+ def render(self, **kwargs: Any) -> List[Any]:
+ return self.venv.render(**kwargs)
+
+ def close(self) -> None:
+ self.venv.close()
+
+
+class VectorEnvNormObs(VectorEnvWrapper):
+ """An observation normalization wrapper for vectorized environments.
+
+ :param bool update_obs_rms: whether to update obs_rms. Default to True.
+ """
+
+ def __init__(
+ self,
+ venv: BaseVectorEnv,
+ update_obs_rms: bool = True,
+ ) -> None:
+ super().__init__(venv)
+ # initialize observation running mean/std
+ self.update_obs_rms = update_obs_rms
+ self.obs_rms = RunningMeanStd()
+
+ def reset(
+ self,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
+ obs, info = self.venv.reset(id, **kwargs)
+
+ if isinstance(obs, tuple): # type: ignore
+ raise TypeError(
+ "Tuple observation space is not supported. ",
+ "Please change it to array or dict space",
+ )
+
+ if self.obs_rms and self.update_obs_rms:
+ self.obs_rms.update(obs)
+ obs = self._norm_obs(obs)
+ return obs, info
+
+ def step(
+ self,
+ action: np.ndarray,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> gym_new_venv_step_type:
+ step_results = self.venv.step(action, id)
+ if self.obs_rms and self.update_obs_rms:
+ self.obs_rms.update(step_results[0])
+ return (self._norm_obs(step_results[0]), *step_results[1:])
+
+ def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
+ if self.obs_rms:
+ return self.obs_rms.norm(obs) # type: ignore
+ return obs
+
+ def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
+ """Set with given observation running mean/std."""
+ self.obs_rms = obs_rms
+
+ def get_obs_rms(self) -> RunningMeanStd:
+ """Return observation running mean/std."""
+ return self.obs_rms
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
new file mode 100644
index 0000000..cc39ba3
--- /dev/null
+++ b/tianshou/env/venvs.py
@@ -0,0 +1,499 @@
+import warnings
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import gymnasium as gym
+import numpy as np
+import packaging
+
+from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
+from tianshou.env.worker import (
+ DummyEnvWorker,
+ EnvWorker,
+ RayEnvWorker,
+ SubprocEnvWorker,
+)
+
+try:
+ from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+ PettingZooEnv = None # type: ignore
+
+try:
+ import gym as old_gym
+
+ has_old_gym = True
+except ImportError:
+ has_old_gym = False
+
+GYM_RESERVED_KEYS = [
+ "metadata", "reward_range", "spec", "action_space", "observation_space"
+]
+
+
+def _patch_env_generator(fn: Callable[[], ENV_TYPE]) -> Callable[[], gym.Env]:
+ """Takes an environment generator and patches it to return Gymnasium envs.
+
+ This function takes the environment generator `fn` and returns a patched
+ generator, without invoking `fn`. The original generator may return
+ Gymnasium or OpenAI Gym environments, but the patched generator wraps
+ the result of `fn` in a shimmy wrapper to convert it to Gymnasium,
+ if necessary.
+ """
+
+ def patched() -> gym.Env:
+ assert callable(
+ fn
+ ), "Env generators that are provided to vector environemnts must be callable."
+
+ env = fn()
+ if isinstance(env, (gym.Env, PettingZooEnv)):
+ return env
+
+ if not has_old_gym or not isinstance(env, old_gym.Env):
+ raise ValueError(
+ f"Environment generator returned a {type(env)}, not a Gymnasium "
+ f"environment. In this case, we expect OpenAI Gym to be "
+ f"installed and the environment to be an OpenAI Gym environment."
+ )
+
+ try:
+ import shimmy
+ except ImportError as e:
+ raise ImportError(
+ "Missing shimmy installation. You provided an environment generator "
+ "that returned an OpenAI Gym environment. "
+ "Tianshou has transitioned to using Gymnasium internally. "
+ "In order to use OpenAI Gym environments with tianshou, you need to "
+ "install shimmy (`pip install shimmy`)."
+ ) from e
+
+ warnings.warn(
+ "You provided an environment generator that returned an OpenAI Gym "
+ "environment. We strongly recommend transitioning to Gymnasium "
+ "environments. "
+ "Tianshou is automatically wrapping your environments in a compatibility "
+ "layer, which could potentially cause issues."
+ )
+
+ gym_version = packaging.version.parse(old_gym.__version__)
+ if gym_version >= packaging.version.parse("0.26.0"):
+ return shimmy.GymV26CompatibilityV0(env=env)
+ elif gym_version >= packaging.version.parse("0.22.0"):
+ return shimmy.GymV22CompatibilityV0(env=env)
+ else:
+ raise Exception(
+ f"Found OpenAI Gym version {gym.__version__}. "
+ f"Tianshou only supports OpenAI Gym environments of version>=0.22.0"
+ )
+
+ return patched
+
+
+class BaseVectorEnv(object):
+ """Base class for vectorized environments.
+
+ Usage:
+ ::
+
+ env_num = 8
+ envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
+ assert len(envs) == env_num
+
+ It accepts a list of environment generators. In other words, an environment
+ generator ``efn`` of a specific task means that ``efn()`` returns the
+ environment of the given task, for example, ``gym.make(task)``.
+
+ All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.
+ Here are some other usages:
+ ::
+
+ envs.seed(2) # which is equal to the next line
+ envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
+ obs = envs.reset() # reset all environments
+ obs = envs.reset([0, 5, 7]) # reset 3 specific environments
+ obs, rew, done, info = envs.step([1] * 8) # step synchronously
+ envs.render() # render all environments
+ envs.close() # close all environments
+
+ .. warning::
+
+ If you use your own environment, please make sure the ``seed`` method
+ is set up properly, e.g.,
+ ::
+
+ def seed(self, seed):
+ np.random.seed(seed)
+
+ Otherwise, the outputs of these envs may be the same with each other.
+
+ :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env.
+ :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
+ worker which contains the i-th env.
+ :param int wait_num: use in asynchronous simulation if the time cost of
+ ``env.step`` varies with time and synchronously waiting for all
+ environments to finish a step is time-wasting. In that case, we can
+ return when ``wait_num`` environments finish a step and keep on
+ simulation in these environments. If ``None``, asynchronous simulation
+ is disabled; else, ``1 <= wait_num <= env_num``.
+ :param float timeout: use in asynchronous simulation same as above, in each
+ vectorized step it only deal with those environments spending time
+ within ``timeout`` seconds.
+ """
+
+ def __init__(
+ self,
+ env_fns: List[Callable[[], ENV_TYPE]],
+ worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
+ wait_num: Optional[int] = None,
+ timeout: Optional[float] = None,
+ ) -> None:
+ self._env_fns = env_fns
+ # A VectorEnv contains a pool of EnvWorkers, which corresponds to
+ # interact with the given envs (one worker <-> one env).
+ self.workers = [worker_fn(_patch_env_generator(fn)) for fn in env_fns]
+ self.worker_class = type(self.workers[0])
+ assert issubclass(self.worker_class, EnvWorker)
+ assert all([isinstance(w, self.worker_class) for w in self.workers])
+
+ self.env_num = len(env_fns)
+ self.wait_num = wait_num or len(env_fns)
+ assert (
+ 1 <= self.wait_num <= len(env_fns)
+ ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
+ self.timeout = timeout
+ assert (
+ self.timeout is None or self.timeout > 0
+ ), f"timeout is {timeout}, it should be positive if provided!"
+ self.is_async = self.wait_num != len(env_fns) or timeout is not None
+ self.waiting_conn: List[EnvWorker] = []
+ # environments in self.ready_id is actually ready
+ # but environments in self.waiting_id are just waiting when checked,
+ # and they may be ready now, but this is not known until we check it
+ # in the step() function
+ self.waiting_id: List[int] = []
+ # all environments are ready in the beginning
+ self.ready_id = list(range(self.env_num))
+ self.is_closed = False
+
+ def _assert_is_not_closed(self) -> None:
+ assert (
+ not self.is_closed
+ ), f"Methods of {self.__class__.__name__} cannot be called after close."
+
+ def __len__(self) -> int:
+ """Return len(self), which is the number of environments."""
+ return self.env_num
+
+ def __getattribute__(self, key: str) -> Any:
+ """Switch the attribute getter depending on the key.
+
+ Any class who inherits ``gym.Env`` will inherit some attributes, like
+ ``action_space``. However, we would like the attribute lookup to go straight
+ into the worker (in fact, this vector env's action_space is always None).
+ """
+ if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env
+ return self.get_env_attr(key)
+ else:
+ return super().__getattribute__(key)
+
+ def get_env_attr(
+ self,
+ key: str,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> List[Any]:
+ """Get an attribute from the underlying environments.
+
+ If id is an int, retrieve the attribute denoted by key from the environment
+ underlying the worker at index id. The result is returned as a list with one
+ element. Otherwise, retrieve the attribute for all workers at indices id and
+ return a list that is ordered correspondingly to id.
+
+ :param str key: The key of the desired attribute.
+ :param id: Indice(s) of the desired worker(s). Default to None for all env_id.
+
+ :return list: The list of environment attributes.
+ """
+ self._assert_is_not_closed()
+ id = self._wrap_id(id)
+ if self.is_async:
+ self._assert_id(id)
+
+ return [self.workers[j].get_env_attr(key) for j in id]
+
+ def set_env_attr(
+ self,
+ key: str,
+ value: Any,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> None:
+ """Set an attribute in the underlying environments.
+
+ If id is an int, set the attribute denoted by key from the environment
+ underlying the worker at index id to value.
+ Otherwise, set the attribute for all workers at indices id.
+
+ :param str key: The key of the desired attribute.
+ :param Any value: The new value of the attribute.
+ :param id: Indice(s) of the desired worker(s). Default to None for all env_id.
+ """
+ self._assert_is_not_closed()
+ id = self._wrap_id(id)
+ if self.is_async:
+ self._assert_id(id)
+ for j in id:
+ self.workers[j].set_env_attr(key, value)
+
+ def _wrap_id(
+ self,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> Union[List[int], np.ndarray]:
+ if id is None:
+ return list(range(self.env_num))
+ return [id] if np.isscalar(id) else id # type: ignore
+
+ def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
+ for i in id:
+ assert (
+ i not in self.waiting_id
+ ), f"Cannot interact with environment {i} which is stepping now."
+ assert (
+ i in self.ready_id
+ ), f"Can only interact with ready environments {self.ready_id}."
+
+ def reset(
+ self,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Tuple[np.ndarray, Union[dict, List[dict]]]:
+ """Reset the state of some envs and return initial observations.
+
+ If id is None, reset the state of all the environments and return
+ initial observations, otherwise reset the specific environments with
+ the given id, either an int or a list.
+ """
+ self._assert_is_not_closed()
+ id = self._wrap_id(id)
+ if self.is_async:
+ self._assert_id(id)
+
+ # send(None) == reset() in worker
+ for i in id:
+ self.workers[i].send(None, **kwargs)
+ ret_list = [self.workers[i].recv() for i in id]
+
+ assert (
+ isinstance(ret_list[0], (tuple, list)) and len(ret_list[0]) == 2
+ and isinstance(ret_list[0][1], dict)
+ )
+
+ obs_list = [r[0] for r in ret_list]
+
+ if isinstance(obs_list[0], tuple): # type: ignore
+ raise TypeError(
+ "Tuple observation space is not supported. ",
+ "Please change it to array or dict space",
+ )
+ try:
+ obs = np.stack(obs_list)
+ except ValueError: # different len(obs)
+ obs = np.array(obs_list, dtype=object)
+
+ infos = [r[1] for r in ret_list]
+ return obs, infos # type: ignore
+
+ def step(
+ self,
+ action: np.ndarray,
+ id: Optional[Union[int, List[int], np.ndarray]] = None,
+ ) -> gym_new_venv_step_type:
+ """Run one timestep of some environments' dynamics.
+
+ If id is None, run one timestep of all the environments’ dynamics;
+ otherwise run one timestep for some environments with given id, either
+ an int or a list. When the end of episode is reached, you are
+ responsible for calling reset(id) to reset this environment’s state.
+
+ Accept a batch of action and return a tuple (batch_obs, batch_rew,
+ batch_done, batch_info) in numpy format.
+
+ :param numpy.ndarray action: a batch of action provided by the agent.
+
+ :return: A tuple consisting of either:
+
+ * ``obs`` a numpy.ndarray, the agent's observation of current environments
+ * ``rew`` a numpy.ndarray, the amount of rewards returned after \
+ previous actions
+ * ``terminated`` a numpy.ndarray, whether these episodes have been \
+ terminated
+ * ``truncated`` a numpy.ndarray, whether these episodes have been truncated
+ * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
+ information (helpful for debugging, and sometimes learning)
+
+ For the async simulation:
+
+ Provide the given action to the environments. The action sequence
+ should correspond to the ``id`` argument, and the ``id`` argument
+ should be a subset of the ``env_id`` in the last returned ``info``
+ (initially they are env_ids of all the environments). If action is
+ None, fetch unfinished step() calls instead.
+ """
+ self._assert_is_not_closed()
+ id = self._wrap_id(id)
+ if not self.is_async:
+ assert len(action) == len(id)
+ for i, j in enumerate(id):
+ self.workers[j].send(action[i])
+ result = []
+ for j in id:
+ env_return = self.workers[j].recv()
+ env_return[-1]["env_id"] = j
+ result.append(env_return)
+ else:
+ if action is not None:
+ self._assert_id(id)
+ assert len(action) == len(id)
+ for act, env_id in zip(action, id):
+ self.workers[env_id].send(act)
+ self.waiting_conn.append(self.workers[env_id])
+ self.waiting_id.append(env_id)
+ self.ready_id = [x for x in self.ready_id if x not in id]
+ ready_conns: List[EnvWorker] = []
+ while not ready_conns:
+ ready_conns = self.worker_class.wait(
+ self.waiting_conn, self.wait_num, self.timeout
+ )
+ result = []
+ for conn in ready_conns:
+ waiting_index = self.waiting_conn.index(conn)
+ self.waiting_conn.pop(waiting_index)
+ env_id = self.waiting_id.pop(waiting_index)
+ # env_return can be (obs, reward, done, info) or
+ # (obs, reward, terminated, truncated, info)
+ env_return = conn.recv()
+ env_return[-1]["env_id"] = env_id # Add `env_id` to info
+ result.append(env_return)
+ self.ready_id.append(env_id)
+ obs_list, rew_list, term_list, trunc_list, info_list = tuple(zip(*result))
+ try:
+ obs_stack = np.stack(obs_list)
+ except ValueError: # different len(obs)
+ obs_stack = np.array(obs_list, dtype=object)
+ return (
+ obs_stack,
+ np.stack(rew_list),
+ np.stack(term_list),
+ np.stack(trunc_list),
+ np.stack(info_list),
+ )
+
+ def seed(
+ self,
+ seed: Optional[Union[int, List[int]]] = None,
+ ) -> List[Optional[List[int]]]:
+ """Set the seed for all environments.
+
+ Accept ``None``, an int (which will extend ``i`` to
+ ``[i, i + 1, i + 2, ...]``) or a list.
+
+ :return: The list of seeds used in this env's random number generators.
+ The first value in the list should be the "main" seed, or the value
+ which a reproducer pass to "seed".
+ """
+ self._assert_is_not_closed()
+ seed_list: Union[List[None], List[int]]
+ if seed is None:
+ seed_list = [seed] * self.env_num
+ elif isinstance(seed, int):
+ seed_list = [seed + i for i in range(self.env_num)]
+ else:
+ seed_list = seed
+ return [w.seed(s) for w, s in zip(self.workers, seed_list)]
+
+ def render(self, **kwargs: Any) -> List[Any]:
+ """Render all of the environments."""
+ self._assert_is_not_closed()
+ if self.is_async and len(self.waiting_id) > 0:
+ raise RuntimeError(
+ f"Environments {self.waiting_id} are still stepping, cannot "
+ "render them now."
+ )
+ return [w.render(**kwargs) for w in self.workers]
+
+ def close(self) -> None:
+ """Close all of the environments.
+
+ This function will be called only once (if not, it will be called during
+ garbage collected). This way, ``close`` of all workers can be assured.
+ """
+ self._assert_is_not_closed()
+ for w in self.workers:
+ w.close()
+ self.is_closed = True
+
+
+class DummyVectorEnv(BaseVectorEnv):
+ """Dummy vectorized environment wrapper, implemented in for-loop.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
+ """
+
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+ super().__init__(env_fns, DummyEnvWorker, **kwargs)
+
+
+class SubprocVectorEnv(BaseVectorEnv):
+ """Vectorized environment wrapper based on subprocess.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
+ """
+
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+
+ def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
+ return SubprocEnvWorker(fn, share_memory=False)
+
+ super().__init__(env_fns, worker_fn, **kwargs)
+
+
+class ShmemVectorEnv(BaseVectorEnv):
+ """Optimized SubprocVectorEnv with shared buffers to exchange observations.
+
+ ShmemVectorEnv has exactly the same API as SubprocVectorEnv.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
+ """
+
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+
+ def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
+ return SubprocEnvWorker(fn, share_memory=True)
+
+ super().__init__(env_fns, worker_fn, **kwargs)
+
+
+class RayVectorEnv(BaseVectorEnv):
+ """Vectorized environment wrapper based on ray.
+
+ This is a choice to run distributed environments in a cluster.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
+ """
+
+ def __init__(self, env_fns: List[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+ try:
+ import ray
+ except ImportError as exception:
+ raise ImportError(
+ "Please install ray to support RayVectorEnv: pip install ray"
+ ) from exception
+ if not ray.is_initialized():
+ ray.init()
+ super().__init__(env_fns, RayEnvWorker, **kwargs)
diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py
new file mode 100644
index 0000000..1b1f375
--- /dev/null
+++ b/tianshou/env/worker/__init__.py
@@ -0,0 +1,11 @@
+from tianshou.env.worker.base import EnvWorker
+from tianshou.env.worker.dummy import DummyEnvWorker
+from tianshou.env.worker.ray import RayEnvWorker
+from tianshou.env.worker.subproc import SubprocEnvWorker
+
+__all__ = [
+ "EnvWorker",
+ "DummyEnvWorker",
+ "SubprocEnvWorker",
+ "RayEnvWorker",
+]
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
new file mode 100644
index 0000000..773d56b
--- /dev/null
+++ b/tianshou/env/worker/base.py
@@ -0,0 +1,106 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import gymnasium as gym
+import numpy as np
+
+from tianshou.env.utils import gym_new_venv_step_type
+from tianshou.utils import deprecation
+
+
+class EnvWorker(ABC):
+ """An abstract worker for an environment."""
+
+ def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
+ self._env_fn = env_fn
+ self.is_closed = False
+ self.result: Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]
+ self.action_space = self.get_env_attr("action_space") # noqa: B009
+ self.is_reset = False
+
+ @abstractmethod
+ def get_env_attr(self, key: str) -> Any:
+ pass
+
+ @abstractmethod
+ def set_env_attr(self, key: str, value: Any) -> None:
+ pass
+
+ def send(self, action: Optional[np.ndarray]) -> None:
+ """Send action signal to low-level worker.
+
+ When action is None, it indicates sending "reset" signal; otherwise
+ it indicates "step" signal. The paired return value from "recv"
+ function is determined by such kind of different signal.
+ """
+ if hasattr(self, "send_action"):
+ deprecation(
+ "send_action will soon be deprecated. "
+ "Please use send and recv for your own EnvWorker."
+ )
+ if action is None:
+ self.is_reset = True
+ self.result = self.reset()
+ else:
+ self.is_reset = False
+ self.send_action(action)
+
+ def recv(
+ self
+ ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]: # noqa:E125
+ """Receive result from low-level worker.
+
+ If the last "send" function sends a NULL action, it only returns a
+ single observation; otherwise it returns a tuple of (obs, rew, done,
+ info) or (obs, rew, terminated, truncated, info), based on whether
+ the environment is using the old step API or the new one.
+ """
+ if hasattr(self, "get_result"):
+ deprecation(
+ "get_result will soon be deprecated. "
+ "Please use send and recv for your own EnvWorker."
+ )
+ if not self.is_reset:
+ self.result = self.get_result()
+ return self.result
+
+ @abstractmethod
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
+ pass
+
+ def step(self, action: np.ndarray) -> gym_new_venv_step_type:
+ """Perform one timestep of the environment's dynamic.
+
+ "send" and "recv" are coupled in sync simulation, so users only call
+ "step" function. But they can be called separately in async
+ simulation, i.e. someone calls "send" first, and calls "recv" later.
+ """
+ self.send(action)
+ return self.recv() # type: ignore
+
+ @staticmethod
+ def wait(
+ workers: List["EnvWorker"],
+ wait_num: int,
+ timeout: Optional[float] = None
+ ) -> List["EnvWorker"]:
+ """Given a list of workers, return those ready ones."""
+ raise NotImplementedError
+
+ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
+ return self.action_space.seed(seed) # issue 299
+
+ @abstractmethod
+ def render(self, **kwargs: Any) -> Any:
+ """Render the environment."""
+ pass
+
+ @abstractmethod
+ def close_env(self) -> None:
+ pass
+
+ def close(self) -> None:
+ if self.is_closed:
+ return None
+ self.is_closed = True
+ self.close_env()
diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py
new file mode 100644
index 0000000..4eec4e0
--- /dev/null
+++ b/tianshou/env/worker/dummy.py
@@ -0,0 +1,52 @@
+from typing import Any, Callable, List, Optional, Tuple
+
+import gymnasium as gym
+import numpy as np
+
+from tianshou.env.worker import EnvWorker
+
+
+class DummyEnvWorker(EnvWorker):
+ """Dummy worker used in sequential vector environments."""
+
+ def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
+ self.env = env_fn()
+ super().__init__(env_fn)
+
+ def get_env_attr(self, key: str) -> Any:
+ return getattr(self.env, key)
+
+ def set_env_attr(self, key: str, value: Any) -> None:
+ setattr(self.env.unwrapped, key, value)
+
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
+ if "seed" in kwargs:
+ super().seed(kwargs["seed"])
+ return self.env.reset(**kwargs)
+
+ @staticmethod
+ def wait( # type: ignore
+ workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None
+ ) -> List["DummyEnvWorker"]:
+ # Sequential EnvWorker objects are always ready
+ return workers
+
+ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+ if action is None:
+ self.result = self.env.reset(**kwargs)
+ else:
+ self.result = self.env.step(action) # type: ignore
+
+ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
+ super().seed(seed)
+ try:
+ return self.env.seed(seed) # type: ignore
+ except (AttributeError, NotImplementedError):
+ self.env.reset(seed=seed)
+ return [seed] # type: ignore
+
+ def render(self, **kwargs: Any) -> Any:
+ return self.env.render(**kwargs)
+
+ def close_env(self) -> None:
+ self.env.close()
diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
new file mode 100644
index 0000000..fe2b8fe
--- /dev/null
+++ b/tianshou/env/worker/ray.py
@@ -0,0 +1,74 @@
+from typing import Any, Callable, List, Optional
+
+import gymnasium as gym
+import numpy as np
+
+from tianshou.env.utils import gym_new_venv_step_type
+from tianshou.env.worker import EnvWorker
+
+try:
+ import ray
+except ImportError:
+ pass
+
+
+class _SetAttrWrapper(gym.Wrapper):
+
+ def set_env_attr(self, key: str, value: Any) -> None:
+ setattr(self.env.unwrapped, key, value)
+
+ def get_env_attr(self, key: str) -> Any:
+ return getattr(self.env, key)
+
+
+class RayEnvWorker(EnvWorker):
+ """Ray worker used in RayVectorEnv."""
+
+ def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
+ self.env = ray.remote(_SetAttrWrapper).options( # type: ignore
+ num_cpus=0
+ ).remote(env_fn())
+ super().__init__(env_fn)
+
+ def get_env_attr(self, key: str) -> Any:
+ return ray.get(self.env.get_env_attr.remote(key))
+
+ def set_env_attr(self, key: str, value: Any) -> None:
+ ray.get(self.env.set_env_attr.remote(key, value))
+
+ def reset(self, **kwargs: Any) -> Any:
+ if "seed" in kwargs:
+ super().seed(kwargs["seed"])
+ return ray.get(self.env.reset.remote(**kwargs))
+
+ @staticmethod
+ def wait( # type: ignore
+ workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None
+ ) -> List["RayEnvWorker"]:
+ results = [x.result for x in workers]
+ ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
+ return [workers[results.index(result)] for result in ready_results]
+
+ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+ # self.result is actually a handle
+ if action is None:
+ self.result = self.env.reset.remote(**kwargs)
+ else:
+ self.result = self.env.step.remote(action)
+
+ def recv(self) -> gym_new_venv_step_type:
+ return ray.get(self.result) # type: ignore
+
+ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
+ super().seed(seed)
+ try:
+ return ray.get(self.env.seed.remote(seed))
+ except (AttributeError, NotImplementedError):
+ self.env.reset.remote(seed=seed)
+ return None
+
+ def render(self, **kwargs: Any) -> Any:
+ return ray.get(self.env.render.remote(**kwargs))
+
+ def close_env(self) -> None:
+ ray.get(self.env.close.remote())
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
new file mode 100644
index 0000000..68f34e6
--- /dev/null
+++ b/tianshou/env/worker/subproc.py
@@ -0,0 +1,256 @@
+import ctypes
+import time
+from collections import OrderedDict
+from multiprocessing import Array, Pipe, connection
+from multiprocessing.context import Process
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import gymnasium as gym
+import numpy as np
+
+from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
+from tianshou.env.worker import EnvWorker
+
+_NP_TO_CT = {
+ np.bool_: ctypes.c_bool,
+ np.uint8: ctypes.c_uint8,
+ np.uint16: ctypes.c_uint16,
+ np.uint32: ctypes.c_uint32,
+ np.uint64: ctypes.c_uint64,
+ np.int8: ctypes.c_int8,
+ np.int16: ctypes.c_int16,
+ np.int32: ctypes.c_int32,
+ np.int64: ctypes.c_int64,
+ np.float32: ctypes.c_float,
+ np.float64: ctypes.c_double,
+}
+
+
+class ShArray:
+ """Wrapper of multiprocessing Array."""
+
+ def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
+ self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
+ self.dtype = dtype
+ self.shape = shape
+
+ def save(self, ndarray: np.ndarray) -> None:
+ assert isinstance(ndarray, np.ndarray)
+ dst = self.arr.get_obj()
+ dst_np = np.frombuffer(dst,
+ dtype=self.dtype).reshape(self.shape) # type: ignore
+ np.copyto(dst_np, ndarray)
+
+ def get(self) -> np.ndarray:
+ obj = self.arr.get_obj()
+ return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore
+
+
+def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
+ if isinstance(space, gym.spaces.Dict):
+ assert isinstance(space.spaces, OrderedDict)
+ return {k: _setup_buf(v) for k, v in space.spaces.items()}
+ elif isinstance(space, gym.spaces.Tuple):
+ assert isinstance(space.spaces, tuple)
+ return tuple([_setup_buf(t) for t in space.spaces])
+ else:
+ return ShArray(space.dtype, space.shape) # type: ignore
+
+
+def _worker(
+ parent: connection.Connection,
+ p: connection.Connection,
+ env_fn_wrapper: CloudpickleWrapper,
+ obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
+) -> None:
+
+ def _encode_obs(
+ obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
+ ) -> None:
+ if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
+ buffer.save(obs)
+ elif isinstance(obs, tuple) and isinstance(buffer, tuple):
+ for o, b in zip(obs, buffer):
+ _encode_obs(o, b)
+ elif isinstance(obs, dict) and isinstance(buffer, dict):
+ for k in obs.keys():
+ _encode_obs(obs[k], buffer[k])
+ return None
+
+ parent.close()
+ env = env_fn_wrapper.data()
+ try:
+ while True:
+ try:
+ cmd, data = p.recv()
+ except EOFError: # the pipe has been closed
+ p.close()
+ break
+ if cmd == "step":
+ env_return = env.step(data)
+ if obs_bufs is not None:
+ _encode_obs(env_return[0], obs_bufs)
+ env_return = (None, *env_return[1:])
+ p.send(env_return)
+ elif cmd == "reset":
+ obs, info = env.reset(**data)
+ if obs_bufs is not None:
+ _encode_obs(obs, obs_bufs)
+ obs = None
+ p.send((obs, info))
+ elif cmd == "close":
+ p.send(env.close())
+ p.close()
+ break
+ elif cmd == "render":
+ p.send(env.render(**data) if hasattr(env, "render") else None)
+ elif cmd == "seed":
+ if hasattr(env, "seed"):
+ p.send(env.seed(data))
+ else:
+ env.reset(seed=data)
+ p.send(None)
+ elif cmd == "getattr":
+ p.send(getattr(env, data) if hasattr(env, data) else None)
+ elif cmd == "setattr":
+ setattr(env.unwrapped, data["key"], data["value"])
+ else:
+ p.close()
+ raise NotImplementedError
+ except KeyboardInterrupt:
+ p.close()
+
+
+class SubprocEnvWorker(EnvWorker):
+ """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
+
+ def __init__(
+ self, env_fn: Callable[[], gym.Env], share_memory: bool = False
+ ) -> None:
+ self.parent_remote, self.child_remote = Pipe()
+ self.share_memory = share_memory
+ self.buffer: Optional[Union[dict, tuple, ShArray]] = None
+ if self.share_memory:
+ dummy = env_fn()
+ obs_space = dummy.observation_space
+ dummy.close()
+ del dummy
+ self.buffer = _setup_buf(obs_space)
+ args = (
+ self.parent_remote,
+ self.child_remote,
+ CloudpickleWrapper(env_fn),
+ self.buffer,
+ )
+ self.process = Process(target=_worker, args=args, daemon=True)
+ self.process.start()
+ self.child_remote.close()
+ super().__init__(env_fn)
+
+ def get_env_attr(self, key: str) -> Any:
+ self.parent_remote.send(["getattr", key])
+ return self.parent_remote.recv()
+
+ def set_env_attr(self, key: str, value: Any) -> None:
+ self.parent_remote.send(["setattr", {"key": key, "value": value}])
+
+ def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
+
+ def decode_obs(
+ buffer: Optional[Union[dict, tuple, ShArray]]
+ ) -> Union[dict, tuple, np.ndarray]:
+ if isinstance(buffer, ShArray):
+ return buffer.get()
+ elif isinstance(buffer, tuple):
+ return tuple([decode_obs(b) for b in buffer])
+ elif isinstance(buffer, dict):
+ return {k: decode_obs(v) for k, v in buffer.items()}
+ else:
+ raise NotImplementedError
+
+ return decode_obs(self.buffer)
+
+ @staticmethod
+ def wait( # type: ignore
+ workers: List["SubprocEnvWorker"],
+ wait_num: int,
+ timeout: Optional[float] = None,
+ ) -> List["SubprocEnvWorker"]:
+ remain_conns = conns = [x.parent_remote for x in workers]
+ ready_conns: List[connection.Connection] = []
+ remain_time, t1 = timeout, time.time()
+ while len(remain_conns) > 0 and len(ready_conns) < wait_num:
+ if timeout:
+ remain_time = timeout - (time.time() - t1)
+ if remain_time <= 0:
+ break
+ # connection.wait hangs if the list is empty
+ new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
+ ready_conns.extend(new_ready_conns) # type: ignore
+ remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
+ return [workers[conns.index(con)] for con in ready_conns]
+
+ def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+ if action is None:
+ if "seed" in kwargs:
+ super().seed(kwargs["seed"])
+ self.parent_remote.send(["reset", kwargs])
+ else:
+ self.parent_remote.send(["step", action])
+
+ def recv(
+ self
+ ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
+ result = self.parent_remote.recv()
+ if isinstance(result, tuple):
+ if len(result) == 2:
+ obs, info = result
+ if self.share_memory:
+ obs = self._decode_obs()
+ return obs, info
+ obs = result[0]
+ if self.share_memory:
+ obs = self._decode_obs()
+ return (obs, *result[1:]) # type: ignore
+ else:
+ obs = result
+ if self.share_memory:
+ obs = self._decode_obs()
+ return obs
+
+ def reset(self, **kwargs: Any) -> Tuple[np.ndarray, dict]:
+ if "seed" in kwargs:
+ super().seed(kwargs["seed"])
+ self.parent_remote.send(["reset", kwargs])
+
+ result = self.parent_remote.recv()
+ if isinstance(result, tuple):
+ obs, info = result
+ if self.share_memory:
+ obs = self._decode_obs()
+ return obs, info
+ else:
+ obs = result
+ if self.share_memory:
+ obs = self._decode_obs()
+ return obs
+
+ def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
+ super().seed(seed)
+ self.parent_remote.send(["seed", seed])
+ return self.parent_remote.recv()
+
+ def render(self, **kwargs: Any) -> Any:
+ self.parent_remote.send(["render", kwargs])
+ return self.parent_remote.recv()
+
+ def close_env(self) -> None:
+ try:
+ self.parent_remote.send(["close", None])
+ # mp may be deleted so it may raise AttributeError
+ self.parent_remote.recv()
+ self.process.join()
+ except (BrokenPipeError, EOFError, AttributeError):
+ pass
+ # ensure the subproc is terminated
+ self.process.terminate()
diff --git a/tianshou/exploration/__init__.py b/tianshou/exploration/__init__.py
new file mode 100644
index 0000000..0878d23
--- /dev/null
+++ b/tianshou/exploration/__init__.py
@@ -0,0 +1,7 @@
+from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise
+
+__all__ = [
+ "BaseNoise",
+ "GaussianNoise",
+ "OUNoise",
+]
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
new file mode 100644
index 0000000..25316e9
--- /dev/null
+++ b/tianshou/exploration/random.py
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+
+class BaseNoise(ABC, object):
+ """The action noise base class."""
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def reset(self) -> None:
+ """Reset to the initial state."""
+ pass
+
+ @abstractmethod
+ def __call__(self, size: Sequence[int]) -> np.ndarray:
+ """Generate new noise."""
+ raise NotImplementedError
+
+
+class GaussianNoise(BaseNoise):
+ """The vanilla Gaussian process, for exploration in DDPG by default."""
+
+ def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
+ super().__init__()
+ self._mu = mu
+ assert 0 <= sigma, "Noise std should not be negative."
+ self._sigma = sigma
+
+ def __call__(self, size: Sequence[int]) -> np.ndarray:
+ return np.random.normal(self._mu, self._sigma, size)
+
+
+class OUNoise(BaseNoise):
+ """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
+
+ Usage:
+ ::
+
+ # init
+ self.noise = OUNoise()
+ # generate noise
+ noise = self.noise(logits.shape, eps)
+
+ For required parameters, you can refer to the stackoverflow page. However,
+ our experiment result shows that (similar to OpenAI SpinningUp) using
+ vanilla Gaussian process has little difference from using the
+ Ornstein-Uhlenbeck process.
+ """
+
+ def __init__(
+ self,
+ mu: float = 0.0,
+ sigma: float = 0.3,
+ theta: float = 0.15,
+ dt: float = 1e-2,
+ x0: Optional[Union[float, np.ndarray]] = None,
+ ) -> None:
+ super().__init__()
+ self._mu = mu
+ self._alpha = theta * dt
+ self._beta = sigma * np.sqrt(dt)
+ self._x0 = x0
+ self.reset()
+
+ def reset(self) -> None:
+ """Reset to the initial state."""
+ self._x = self._x0
+
+ def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarray:
+ """Generate new noise.
+
+ Return an numpy array which size is equal to ``size``.
+ """
+ if self._x is None or isinstance(
+ self._x, np.ndarray
+ ) and self._x.shape != size:
+ self._x = 0.0
+ if mu is None:
+ mu = self._mu
+ r = self._beta * np.random.normal(size=size)
+ self._x = self._x + self._alpha * (mu - self._x) + r
+ return self._x # type: ignore
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
new file mode 100644
index 0000000..c8fa45e
--- /dev/null
+++ b/tianshou/policy/__init__.py
@@ -0,0 +1,66 @@
+"""Policy package."""
+# isort:skip_file
+
+from tianshou.policy.base import BasePolicy
+from tianshou.policy.random import RandomPolicy
+from tianshou.policy.modelfree.dqn import DQNPolicy
+from tianshou.policy.modelfree.bdq import BranchingDQNPolicy
+from tianshou.policy.modelfree.c51 import C51Policy
+from tianshou.policy.modelfree.rainbow import RainbowPolicy
+from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
+from tianshou.policy.modelfree.iqn import IQNPolicy
+from tianshou.policy.modelfree.fqf import FQFPolicy
+from tianshou.policy.modelfree.pg import PGPolicy
+from tianshou.policy.modelfree.a2c import A2CPolicy
+from tianshou.policy.modelfree.npg import NPGPolicy
+from tianshou.policy.modelfree.ddpg import DDPGPolicy
+from tianshou.policy.modelfree.ppo import PPOPolicy
+from tianshou.policy.modelfree.trpo import TRPOPolicy
+from tianshou.policy.modelfree.td3 import TD3Policy
+from tianshou.policy.modelfree.sac import SACPolicy
+from tianshou.policy.modelfree.redq import REDQPolicy
+from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
+from tianshou.policy.imitation.base import ImitationPolicy
+from tianshou.policy.imitation.bcq import BCQPolicy
+from tianshou.policy.imitation.cql import CQLPolicy
+from tianshou.policy.imitation.td3_bc import TD3BCPolicy
+from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
+from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
+from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
+from tianshou.policy.imitation.gail import GAILPolicy
+from tianshou.policy.modelbased.psrl import PSRLPolicy
+from tianshou.policy.modelbased.icm import ICMPolicy
+from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
+
+__all__ = [
+ "BasePolicy",
+ "RandomPolicy",
+ "DQNPolicy",
+ "BranchingDQNPolicy",
+ "C51Policy",
+ "RainbowPolicy",
+ "QRDQNPolicy",
+ "IQNPolicy",
+ "FQFPolicy",
+ "PGPolicy",
+ "A2CPolicy",
+ "NPGPolicy",
+ "DDPGPolicy",
+ "PPOPolicy",
+ "TRPOPolicy",
+ "TD3Policy",
+ "SACPolicy",
+ "REDQPolicy",
+ "DiscreteSACPolicy",
+ "ImitationPolicy",
+ "BCQPolicy",
+ "CQLPolicy",
+ "TD3BCPolicy",
+ "DiscreteBCQPolicy",
+ "DiscreteCQLPolicy",
+ "DiscreteCRRPolicy",
+ "GAILPolicy",
+ "PSRLPolicy",
+ "ICMPolicy",
+ "MultiAgentPolicyManager",
+]
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
new file mode 100644
index 0000000..584bb09
--- /dev/null
+++ b/tianshou/policy/base.py
@@ -0,0 +1,456 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import gymnasium as gym
+import numpy as np
+import torch
+from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
+from numba import njit
+from torch import nn
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.utils import MultipleLRSchedulers
+
+
+class BasePolicy(ABC, nn.Module):
+ """The base class for any RL policy.
+
+ Tianshou aims to modularize RL algorithms. It comes into several classes of
+ policies in Tianshou. All of the policy classes must inherit
+ :class:`~tianshou.policy.BasePolicy`.
+
+ A policy class typically has the following parts:
+
+ * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \
+ coping the target network and so on;
+ * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
+ observation;
+ * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \
+ replay buffer (this function can interact with replay buffer);
+ * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \
+ data.
+ * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \
+ from the learning process (e.g., prioritized replay buffer needs to update \
+ the weight);
+ * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \
+ i.e., `process_fn -> learn -> post_process_fn`.
+
+ Most of the policy needs a neural network to predict the action and an
+ optimizer to optimize the policy. The rules of self-defined networks are:
+
+ 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \
+ dict or any others), hidden state "state" (for RNN usage), and other information \
+ "info" provided by the environment.
+ 2. Output: some "logits", the next hidden state "state", and the intermediate \
+ result during policy forwarding procedure "policy". The "logits" could be a tuple \
+ instead of a ``torch.Tensor``. It depends on how the policy process the network \
+ output. For example, in PPO, the return of the network might be \
+ ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \
+ torch.Tensor or other things, which will be stored in the replay buffer, and can \
+ be accessed in the policy update process (e.g. in "policy.learn()", the \
+ "batch.policy" is what you need).
+
+ Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can
+ use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``,
+ for instance, loading and saving the model:
+ ::
+
+ torch.save(policy.state_dict(), "policy.pth")
+ policy.load_state_dict(torch.load("policy.pth"))
+ """
+
+ def __init__(
+ self,
+ observation_space: Optional[gym.Space] = None,
+ action_space: Optional[gym.Space] = None,
+ action_scaling: bool = False,
+ action_bound_method: str = "",
+ lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
+ MultipleLRSchedulers]] = None,
+ ) -> None:
+ super().__init__()
+ self.observation_space = observation_space
+ self.action_space = action_space
+ self.action_type = ""
+ if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+ self.action_type = "discrete"
+ elif isinstance(action_space, Box):
+ self.action_type = "continuous"
+ self.agent_id = 0
+ self.updating = False
+ self.action_scaling = action_scaling
+ # can be one of ("clip", "tanh", ""), empty string means no bounding
+ assert action_bound_method in ("", "clip", "tanh")
+ self.action_bound_method = action_bound_method
+ self.lr_scheduler = lr_scheduler
+ self._compile()
+
+ def set_agent_id(self, agent_id: int) -> None:
+ """Set self.agent_id = agent_id, for MARL."""
+ self.agent_id = agent_id
+
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
+ """Modify the action from policy.forward with exploration noise.
+
+ :param act: a data batch or numpy.ndarray which is the action taken by
+ policy.forward.
+ :param batch: the input batch for policy.forward, kept for advanced usage.
+
+ :return: action in the same form of input "act" but with added exploration
+ noise.
+ """
+ return act
+
+ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
+ """Softly update the parameters of target module towards the parameters \
+ of source module."""
+ for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
+ tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
+
+ @abstractmethod
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
+
+ * ``act`` an numpy.ndarray or a torch.Tensor, the action over \
+ given batch data.
+ * ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
+ internal state of the policy, ``None`` as default.
+
+ Other keys are user-defined. It depends on the algorithm. For example,
+ ::
+
+ # some code
+ return Batch(logits=..., act=..., state=None, dist=...)
+
+ The keyword ``policy`` is reserved and the corresponding data will be
+ stored into the replay buffer. For instance,
+ ::
+
+ # some code
+ return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
+ # and in the sampled data batch, you can directly use
+ # batch.policy.log_prob to get your data.
+
+ .. note::
+
+ In continuous action space, you should do another step "map_action" to get
+ the real action:
+ ::
+
+ act = policy(batch).act # doesn't map to the target action range
+ act = policy.map_action(act, batch)
+ """
+ pass
+
+ def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
+ """Map raw network output to action range in gym's env.action_space.
+
+ This function is called in :meth:`~tianshou.data.Collector.collect` and only
+ affects action sending to env. Remapped action will not be stored in buffer
+ and thus can be viewed as a part of env (a black box action transformation).
+
+ Action mapping includes 2 standard procedures: bounding and scaling. Bounding
+ procedure expects original action range is (-inf, inf) and maps it to [-1, 1],
+ while scaling procedure expects original action range is (-1, 1) and maps it
+ to [action_space.low, action_space.high]. Bounding procedure is applied first.
+
+ :param act: a data batch or numpy.ndarray which is the action taken by
+ policy.forward.
+
+ :return: action in the same form of input "act" but remap to the target action
+ space.
+ """
+ if isinstance(self.action_space, gym.spaces.Box) and \
+ isinstance(act, np.ndarray):
+ # currently this action mapping only supports np.ndarray action
+ if self.action_bound_method == "clip":
+ act = np.clip(act, -1.0, 1.0)
+ elif self.action_bound_method == "tanh":
+ act = np.tanh(act)
+ if self.action_scaling:
+ assert np.min(act) >= -1.0 and np.max(act) <= 1.0, \
+ "action scaling only accepts raw action range = [-1, 1]"
+ low, high = self.action_space.low, self.action_space.high
+ act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
+ return act
+
+ def map_action_inverse(
+ self, act: Union[Batch, List, np.ndarray]
+ ) -> Union[Batch, List, np.ndarray]:
+ """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
+
+ This function is called in :meth:`~tianshou.data.Collector.collect` for
+ random initial steps. It scales [action_space.low, action_space.high] to
+ the value ranges of policy.forward.
+
+ :param act: a data batch, list or numpy.ndarray which is the action taken
+ by gym.spaces.Box.sample().
+
+ :return: action remapped.
+ """
+ if isinstance(self.action_space, gym.spaces.Box):
+ act = to_numpy(act)
+ if isinstance(act, np.ndarray):
+ if self.action_scaling:
+ low, high = self.action_space.low, self.action_space.high
+ scale = high - low
+ eps = np.finfo(np.float32).eps.item()
+ scale[scale < eps] += eps
+ act = (act - low) * 2.0 / scale - 1.0
+ if self.action_bound_method == "tanh":
+ act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
+ return act
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ """Pre-process the data from the provided replay buffer.
+
+ Used in :meth:`update`. Check out :ref:`process_fn` for more information.
+ """
+ return batch
+
+ @abstractmethod
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
+ """Update policy with a given batch of data.
+
+ :return: A dict, including the data needed to be logged (e.g., loss).
+
+ .. note::
+
+ In order to distinguish the collecting state, updating state and
+ testing state, you can check the policy state by ``self.training``
+ and ``self.updating``. Please refer to :ref:`policy_state` for more
+ detailed explanation.
+
+ .. warning::
+
+ If you use ``torch.distributions.Normal`` and
+ ``torch.distributions.Categorical`` to calculate the log_prob,
+ please be careful about the shape: Categorical distribution gives
+ "[batch_size]" shape while Normal distribution gives "[batch_size,
+ 1]" shape. The auto-broadcasting of numerical operation with torch
+ tensors will amplify this error.
+ """
+ pass
+
+ def post_process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> None:
+ """Post-process the data from the provided replay buffer.
+
+ Typical usage is to update the sampling weight in prioritized
+ experience replay. Used in :meth:`update`.
+ """
+ if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
+ buffer.update_weight(indices, batch.weight)
+
+ def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
+ **kwargs: Any) -> Dict[str, Any]:
+ """Update the policy network and replay buffer.
+
+ It includes 3 function steps: process_fn, learn, and post_process_fn. In
+ addition, this function will change the value of ``self.updating``: it will be
+ False before this function and will be True when executing :meth:`update`.
+ Please refer to :ref:`policy_state` for more detailed explanation.
+
+ :param int sample_size: 0 means it will extract all the data from the buffer,
+ otherwise it will sample a batch with given sample_size.
+ :param ReplayBuffer buffer: the corresponding replay buffer.
+
+ :return: A dict, including the data needed to be logged (e.g., loss) from
+ ``policy.learn()``.
+ """
+ if buffer is None:
+ return {}
+ batch, indices = buffer.sample(sample_size)
+ self.updating = True
+ batch = self.process_fn(batch, buffer, indices)
+ result = self.learn(batch, **kwargs)
+ self.post_process_fn(batch, buffer, indices)
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.step()
+ self.updating = False
+ return result
+
+ @staticmethod
+ def value_mask(buffer: ReplayBuffer, indices: np.ndarray) -> np.ndarray:
+ """Value mask determines whether the obs_next of buffer[indices] is valid.
+
+ For instance, usually "obs_next" after "done" flag is considered to be invalid,
+ and its q/advantage value can provide meaningless (even misleading)
+ information, and should be set to 0 by hand. But if "done" flag is generated
+ because timelimit of game length (info["TimeLimit.truncated"] is set to True in
+ gym's settings), "obs_next" will instead be valid. Value mask is typically used
+ for assisting in calculating the correct q/advantage value.
+
+ :param ReplayBuffer buffer: the corresponding replay buffer.
+ :param numpy.ndarray indices: indices of replay buffer whose "obs_next" will be
+ judged.
+
+ :return: A bool type numpy.ndarray in the same shape with indices. "True" means
+ "obs_next" of that buffer[indices] is valid.
+ """
+ return ~buffer.terminated[indices]
+
+ @staticmethod
+ def compute_episodic_return(
+ batch: Batch,
+ buffer: ReplayBuffer,
+ indices: np.ndarray,
+ v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ gamma: float = 0.99,
+ gae_lambda: float = 0.95,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Compute returns over given batch.
+
+ Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
+ to calculate q/advantage value of given batch.
+
+ :param Batch batch: a data batch which contains several episodes of data in
+ sequential order. Mind that the end of each finished episode of batch
+ should be marked by done flag, unfinished (or collecting) episodes will be
+ recognized by buffer.unfinished_index().
+ :param numpy.ndarray indices: tell batch's location in buffer, batch is equal
+ to buffer[indices].
+ :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
+ :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+ :param float gae_lambda: the parameter for Generalized Advantage Estimation,
+ should be in [0, 1]. Default to 0.95.
+
+ :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
+ """
+ rew = batch.rew
+ if v_s_ is None:
+ assert np.isclose(gae_lambda, 1.0)
+ v_s_ = np.zeros_like(rew)
+ else:
+ v_s_ = to_numpy(v_s_.flatten())
+ v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices)
+ v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
+
+ end_flag = np.logical_or(batch.terminated, batch.truncated)
+ end_flag[np.isin(indices, buffer.unfinished_index())] = True
+ advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
+ returns = advantage + v_s
+ # normalization varies from each policy, so we don't do it here
+ return returns, advantage
+
+ @staticmethod
+ def compute_nstep_return(
+ batch: Batch,
+ buffer: ReplayBuffer,
+ indice: np.ndarray,
+ target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
+ gamma: float = 0.99,
+ n_step: int = 1,
+ rew_norm: bool = False,
+ ) -> Batch:
+ r"""Compute n-step return for Q-learning targets.
+
+ .. math::
+ G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
+ \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
+
+ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
+ :math:`d_t` is the done flag of step :math:`t`.
+
+ :param Batch batch: a data batch, which is equal to buffer[indice].
+ :param ReplayBuffer buffer: the data buffer.
+ :param function target_q_fn: a function which compute target Q value
+ of "obs_next" given data buffer and wanted indices.
+ :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+ :param int n_step: the number of estimation step, should be an int greater
+ than 0. Default to 1.
+ :param bool rew_norm: normalize the reward to Normal(0, 1), Default to False.
+
+ :return: a Batch. The result will be stored in batch.returns as a
+ torch.Tensor with the same shape as target_q_fn's return tensor.
+ """
+ assert not rew_norm, \
+ "Reward normalization in computing n-step returns is unsupported now."
+ rew = buffer.rew
+ bsz = len(indice)
+ indices = [indice]
+ for _ in range(n_step - 1):
+ indices.append(buffer.next(indices[-1]))
+ indices = np.stack(indices)
+ # terminal indicates buffer indexes nstep after 'indice',
+ # and are truncated at the end of each episode
+ terminal = indices[-1]
+ with torch.no_grad():
+ target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?)
+ target_q = to_numpy(target_q_torch.reshape(bsz, -1))
+ target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
+ end_flag = buffer.done.copy()
+ end_flag[buffer.unfinished_index()] = True
+ target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step)
+
+ batch.returns = to_torch_as(target_q, target_q_torch)
+ if hasattr(batch, "weight"): # prio buffer update
+ batch.weight = to_torch_as(batch.weight, target_q_torch)
+ return batch
+
+ def _compile(self) -> None:
+ f64 = np.array([0, 1], dtype=np.float64)
+ f32 = np.array([0, 1], dtype=np.float32)
+ b = np.array([False, True], dtype=np.bool_)
+ i64 = np.array([[0, 1]], dtype=np.int64)
+ _gae_return(f64, f64, f64, b, 0.1, 0.1)
+ _gae_return(f32, f32, f64, b, 0.1, 0.1)
+ _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)
+
+
+@njit
+def _gae_return(
+ v_s: np.ndarray,
+ v_s_: np.ndarray,
+ rew: np.ndarray,
+ end_flag: np.ndarray,
+ gamma: float,
+ gae_lambda: float,
+) -> np.ndarray:
+ returns = np.zeros(rew.shape)
+ delta = rew + v_s_ * gamma - v_s
+ discount = (1.0 - end_flag) * (gamma * gae_lambda)
+ gae = 0.0
+ for i in range(len(rew) - 1, -1, -1):
+ gae = delta[i] + discount[i] * gae
+ returns[i] = gae
+ return returns
+
+
+@njit
+def _nstep_return(
+ rew: np.ndarray,
+ end_flag: np.ndarray,
+ target_q: np.ndarray,
+ indices: np.ndarray,
+ gamma: float,
+ n_step: int,
+) -> np.ndarray:
+ gamma_buffer = np.ones(n_step + 1)
+ for i in range(1, n_step + 1):
+ gamma_buffer[i] = gamma_buffer[i - 1] * gamma
+ target_shape = target_q.shape
+ bsz = target_shape[0]
+ # change target_q to 2d array
+ target_q = target_q.reshape(bsz, -1)
+ returns = np.zeros(target_q.shape)
+ gammas = np.full(indices[0].shape, n_step)
+ for n in range(n_step - 1, -1, -1):
+ now = indices[n]
+ gammas[end_flag[now] > 0] = n + 1
+ returns[end_flag[now] > 0] = 0.0
+ returns = rew[now].reshape(bsz, 1) + gamma * returns
+ target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
+ return target_q.reshape(target_shape)
diff --git a/tianshou/policy/imitation/__init__.py b/tianshou/policy/imitation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
new file mode 100644
index 0000000..211c093
--- /dev/null
+++ b/tianshou/policy/imitation/base.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_torch
+from tianshou.policy import BasePolicy
+
+
+class ImitationPolicy(BasePolicy):
+ """Implementation of vanilla imitation learning.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> a)
+ :param torch.optim.Optimizer optim: for optimizing the model.
+ :param gym.Space action_space: env's action space.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.model = model
+ self.optim = optim
+ assert self.action_type in ["continuous", "discrete"], \
+ "Please specify action_space."
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ logits, hidden = self.model(batch.obs, state=state, info=batch.info)
+ if self.action_type == "discrete":
+ act = logits.max(dim=1)[1]
+ else:
+ act = logits
+ return Batch(logits=logits, act=act, state=hidden)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ self.optim.zero_grad()
+ if self.action_type == "continuous": # regression
+ act = self(batch).act
+ act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
+ loss = F.mse_loss(act, act_target)
+ elif self.action_type == "discrete": # classification
+ act = F.log_softmax(self(batch).logits, dim=-1)
+ act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
+ loss = F.nll_loss(act, act_target)
+ loss.backward()
+ self.optim.step()
+ return {"loss": loss.item()}
diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py
new file mode 100644
index 0000000..883400d
--- /dev/null
+++ b/tianshou/policy/imitation/bcq.py
@@ -0,0 +1,206 @@
+import copy
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_torch
+from tianshou.policy import BasePolicy
+from tianshou.utils.net.continuous import VAE
+
+
+class BCQPolicy(BasePolicy):
+ """Implementation of BCQ algorithm. arXiv:1812.02900.
+
+ :param Perturbation actor: the actor perturbation. (s, a -> perturbed a)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param VAE vae: the VAE network, generating actions similar
+ to those in batch. (s, a -> generated a)
+ :param torch.optim.Optimizer vae_optim: the optimizer for the VAE network.
+ :param Union[str, torch.device] device: which device to create this model on.
+ Default to "cpu".
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param float tau: param for soft update of the target network.
+ Default to 0.005.
+ :param float lmbda: param for Clipped Double Q-learning. Default to 0.75.
+ :param int forward_sampled_times: the number of sampled actions in forward
+ function. The policy samples many actions and takes the action with the
+ max value. Default to 100.
+ :param int num_sampled_action: the number of sampled actions in calculating
+ target Q. The algorithm samples several actions using VAE, and perturbs
+ each action to get the target Q. Default to 10.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ vae: VAE,
+ vae_optim: torch.optim.Optimizer,
+ device: Union[str, torch.device] = "cpu",
+ gamma: float = 0.99,
+ tau: float = 0.005,
+ lmbda: float = 0.75,
+ forward_sampled_times: int = 100,
+ num_sampled_action: int = 10,
+ **kwargs: Any
+ ) -> None:
+ # actor is Perturbation!
+ super().__init__(**kwargs)
+ self.actor = actor
+ self.actor_target = copy.deepcopy(self.actor)
+ self.actor_optim = actor_optim
+
+ self.critic1 = critic1
+ self.critic1_target = copy.deepcopy(self.critic1)
+ self.critic1_optim = critic1_optim
+
+ self.critic2 = critic2
+ self.critic2_target = copy.deepcopy(self.critic2)
+ self.critic2_optim = critic2_optim
+
+ self.vae = vae
+ self.vae_optim = vae_optim
+
+ self.gamma = gamma
+ self.tau = tau
+ self.lmbda = lmbda
+ self.device = device
+ self.forward_sampled_times = forward_sampled_times
+ self.num_sampled_action = num_sampled_action
+
+ def train(self, mode: bool = True) -> "BCQPolicy":
+ """Set the module in training mode, except for the target network."""
+ self.training = mode
+ self.actor.train(mode)
+ self.critic1.train(mode)
+ self.critic2.train(mode)
+ return self
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data."""
+ # There is "obs" in the Batch
+ # obs_group: several groups. Each group has a state.
+ obs_group: torch.Tensor = to_torch(batch.obs, device=self.device)
+ act_group = []
+ for obs in obs_group:
+ # now obs is (state_dim)
+ obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
+ # now obs is (forward_sampled_times, state_dim)
+
+ # decode(obs) generates action and actor perturbs it
+ act = self.actor(obs, self.vae.decode(obs))
+ # now action is (forward_sampled_times, action_dim)
+ q1 = self.critic1(obs, act)
+ # q1 is (forward_sampled_times, 1)
+ max_indice = q1.argmax(0)
+ act_group.append(act[max_indice].cpu().data.numpy().flatten())
+ act_group = np.array(act_group)
+ return Batch(act=act_group)
+
+ def sync_weight(self) -> None:
+ """Soft-update the weight for the target network."""
+ self.soft_update(self.critic1_target, self.critic1, self.tau)
+ self.soft_update(self.critic2_target, self.critic2, self.tau)
+ self.soft_update(self.actor_target, self.actor, self.tau)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # batch: obs, act, rew, done, obs_next. (numpy array)
+ # (batch_size, state_dim)
+ batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
+ obs, act = batch.obs, batch.act
+ batch_size = obs.shape[0]
+
+ # mean, std: (state.shape[0], latent_dim)
+ recon, mean, std = self.vae(obs, act)
+ recon_loss = F.mse_loss(act, recon)
+ # (....) is D_KL( N(mu, sigma) || N(0,1) )
+ KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean()
+ vae_loss = recon_loss + KL_loss / 2
+
+ self.vae_optim.zero_grad()
+ vae_loss.backward()
+ self.vae_optim.step()
+
+ # critic training:
+ with torch.no_grad():
+ # repeat num_sampled_action times
+ obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0)
+ # now obs_next: (num_sampled_action * batch_size, state_dim)
+
+ # perturbed action generated by VAE
+ act_next = self.vae.decode(obs_next)
+ # now obs_next: (num_sampled_action * batch_size, action_dim)
+ target_Q1 = self.critic1_target(obs_next, act_next)
+ target_Q2 = self.critic2_target(obs_next, act_next)
+
+ # Clipped Double Q-learning
+ target_Q = \
+ self.lmbda * torch.min(target_Q1, target_Q2) + \
+ (1 - self.lmbda) * torch.max(target_Q1, target_Q2)
+ # now target_Q: (num_sampled_action * batch_size, 1)
+
+ # the max value of Q
+ target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1)
+ # now target_Q: (batch_size, 1)
+
+ target_Q = \
+ batch.rew.reshape(-1, 1) + \
+ (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
+
+ current_Q1 = self.critic1(obs, act)
+ current_Q2 = self.critic2(obs, act)
+
+ critic1_loss = F.mse_loss(current_Q1, target_Q)
+ critic2_loss = F.mse_loss(current_Q2, target_Q)
+
+ self.critic1_optim.zero_grad()
+ self.critic2_optim.zero_grad()
+ critic1_loss.backward()
+ critic2_loss.backward()
+ self.critic1_optim.step()
+ self.critic2_optim.step()
+
+ sampled_act = self.vae.decode(obs)
+ perturbed_act = self.actor(obs, sampled_act)
+
+ # max
+ actor_loss = -self.critic1(obs, perturbed_act).mean()
+
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+
+ # update target network
+ self.sync_weight()
+
+ result = {
+ "loss/actor": actor_loss.item(),
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ "loss/vae": vae_loss.item(),
+ }
+ return result
diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py
new file mode 100644
index 0000000..7d9cea9
--- /dev/null
+++ b/tianshou/policy/imitation/cql.py
@@ -0,0 +1,288 @@
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.nn.utils import clip_grad_norm_
+
+from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.policy import SACPolicy
+from tianshou.utils.net.continuous import ActorProb
+
+
+class CQLPolicy(SACPolicy):
+ """Implementation of CQL algorithm. arXiv:2006.04779.
+
+ :param ActorProb actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> a)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param float cql_alpha_lr: the learning rate of cql_log_alpha. Default to 1e-4.
+ :param float cql_weight: the value of alpha. Default to 1.0.
+ :param float tau: param for soft update of the target network.
+ Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+ regularization coefficient. Default to 0.2.
+ If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+ alpha is automatically tuned.
+ :param float temperature: the value of temperature. Default to 1.0.
+ :param bool with_lagrange: whether to use Lagrange. Default to True.
+ :param float lagrange_threshold: the value of tau in CQL(Lagrange).
+ Default to 10.0.
+ :param float min_action: The minimum value of each dimension of action.
+ Default to -1.0.
+ :param float max_action: The maximum value of each dimension of action.
+ Default to 1.0.
+ :param int num_repeat_actions: The number of times the action is repeated
+ when calculating log-sum-exp. Default to 10.
+ :param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
+ :param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
+ :param float clip_grad: clip_grad for updating critic network. Default to 1.0.
+ :param Union[str, torch.device] device: which device to create this model on.
+ Default to "cpu".
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: ActorProb,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ cql_alpha_lr: float = 1e-4,
+ cql_weight: float = 1.0,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
+ temperature: float = 1.0,
+ with_lagrange: bool = True,
+ lagrange_threshold: float = 10.0,
+ min_action: float = -1.0,
+ max_action: float = 1.0,
+ num_repeat_actions: int = 10,
+ alpha_min: float = 0.0,
+ alpha_max: float = 1e6,
+ clip_grad: float = 1.0,
+ device: Union[str, torch.device] = "cpu",
+ **kwargs: Any
+ ) -> None:
+ super().__init__(
+ actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
+ gamma, alpha, **kwargs
+ )
+ # There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy.
+ self.device = device
+ self.temperature = temperature
+ self.with_lagrange = with_lagrange
+ self.lagrange_threshold = lagrange_threshold
+
+ self.cql_weight = cql_weight
+
+ self.cql_log_alpha = torch.tensor([0.0], requires_grad=True)
+ self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)
+ self.cql_log_alpha = self.cql_log_alpha.to(device)
+
+ self.min_action = min_action
+ self.max_action = max_action
+
+ self.num_repeat_actions = num_repeat_actions
+
+ self.alpha_min = alpha_min
+ self.alpha_max = alpha_max
+ self.clip_grad = clip_grad
+
+ def train(self, mode: bool = True) -> "CQLPolicy":
+ """Set the module in training mode, except for the target network."""
+ self.training = mode
+ self.actor.train(mode)
+ self.critic1.train(mode)
+ self.critic2.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ """Soft-update the weight for the target network."""
+ self.soft_update(self.critic1_old, self.critic1, self.tau)
+ self.soft_update(self.critic2_old, self.critic2, self.tau)
+
+ def actor_pred(self, obs: torch.Tensor) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+ batch = Batch(obs=obs, info=None)
+ obs_result = self(batch)
+ return obs_result.act, obs_result.log_prob
+
+ def calc_actor_loss(self, obs: torch.Tensor) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+ act_pred, log_pi = self.actor_pred(obs)
+ q1 = self.critic1(obs, act_pred)
+ q2 = self.critic2(obs, act_pred)
+ min_Q = torch.min(q1, q2)
+ self._alpha: Union[float, torch.Tensor]
+ actor_loss = (self._alpha * log_pi - min_Q).mean()
+ # actor_loss.shape: (), log_pi.shape: (batch_size, 1)
+ return actor_loss, log_pi
+
+ def calc_pi_values(self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+ act_pred, log_pi = self.actor_pred(obs_pi)
+
+ q1 = self.critic1(obs_to_pred, act_pred)
+ q2 = self.critic2(obs_to_pred, act_pred)
+
+ return q1 - log_pi.detach(), q2 - log_pi.detach()
+
+ def calc_random_values(self, obs: torch.Tensor, act: torch.Tensor) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+ random_value1 = self.critic1(obs, act)
+ random_log_prob1 = np.log(0.5**act.shape[-1])
+
+ random_value2 = self.critic2(obs, act)
+ random_log_prob2 = np.log(0.5**act.shape[-1])
+
+ return random_value1 - random_log_prob1, random_value2 - random_log_prob2
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ return batch
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
+ obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
+ batch_size = obs.shape[0]
+
+ # compute actor loss and update actor
+ actor_loss, log_pi = self.calc_actor_loss(obs)
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+
+ # compute alpha loss
+ if self._is_auto_alpha:
+ log_pi = log_pi + self._target_entropy
+ alpha_loss = -(self._log_alpha * log_pi.detach()).mean()
+ self._alpha_optim.zero_grad()
+ # update log_alpha
+ alpha_loss.backward()
+ self._alpha_optim.step()
+ # update alpha
+ self._alpha = self._log_alpha.detach().exp()
+
+ # compute target_Q
+ with torch.no_grad():
+ act_next, new_log_pi = self.actor_pred(obs_next)
+
+ target_Q1 = self.critic1_old(obs_next, act_next)
+ target_Q2 = self.critic2_old(obs_next, act_next)
+
+ target_Q = torch.min(target_Q1, target_Q2) - self._alpha * new_log_pi
+
+ target_Q = \
+ rew + self._gamma * (1 - batch.done) * target_Q.flatten()
+ # shape: (batch_size)
+
+ # compute critic loss
+ current_Q1 = self.critic1(obs, act).flatten()
+ current_Q2 = self.critic2(obs, act).flatten()
+ # shape: (batch_size)
+
+ critic1_loss = F.mse_loss(current_Q1, target_Q)
+ critic2_loss = F.mse_loss(current_Q2, target_Q)
+
+ # CQL
+ random_actions = torch.FloatTensor(
+ batch_size * self.num_repeat_actions, act.shape[-1]
+ ).uniform_(-self.min_action, self.max_action).to(self.device)
+
+ obs_len = len(obs.shape)
+ repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1)
+ view_size = [batch_size * self.num_repeat_actions] + list(obs.shape[1:])
+ tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size)
+ tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size)
+ # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
+
+ current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
+ next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)
+
+ random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)
+
+ for value in [
+ current_pi_value1, current_pi_value2, next_pi_value1, next_pi_value2,
+ random_value1, random_value2
+ ]:
+ value.reshape(batch_size, self.num_repeat_actions, 1)
+
+ # cat q values
+ cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
+ cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
+ # shape: (batch_size, 3 * num_repeat, 1)
+
+ cql1_scaled_loss = \
+ torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * \
+ self.cql_weight * self.temperature - current_Q1.mean() * \
+ self.cql_weight
+ cql2_scaled_loss = \
+ torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * \
+ self.cql_weight * self.temperature - current_Q2.mean() * \
+ self.cql_weight
+ # shape: (1)
+
+ if self.with_lagrange:
+ cql_alpha = torch.clamp(
+ self.cql_log_alpha.exp(),
+ self.alpha_min,
+ self.alpha_max,
+ )
+ cql1_scaled_loss = \
+ cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
+ cql2_scaled_loss = \
+ cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)
+
+ self.cql_alpha_optim.zero_grad()
+ cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
+ cql_alpha_loss.backward(retain_graph=True)
+ self.cql_alpha_optim.step()
+
+ critic1_loss = critic1_loss + cql1_scaled_loss
+ critic2_loss = critic2_loss + cql2_scaled_loss
+
+ # update critic
+ self.critic1_optim.zero_grad()
+ critic1_loss.backward(retain_graph=True)
+ # clip grad, prevent the vanishing gradient problem
+ # It doesn't seem necessary
+ clip_grad_norm_(self.critic1.parameters(), self.clip_grad)
+ self.critic1_optim.step()
+
+ self.critic2_optim.zero_grad()
+ critic2_loss.backward()
+ clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
+ self.critic2_optim.step()
+
+ self.sync_weight()
+
+ result = {
+ "loss/actor": actor_loss.item(),
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ }
+ if self._is_auto_alpha:
+ result["loss/alpha"] = alpha_loss.item()
+ result["alpha"] = self._alpha.item() # type: ignore
+ if self.with_lagrange:
+ result["loss/cql_alpha"] = cql_alpha_loss.item()
+ result["cql_alpha"] = cql_alpha.item()
+ return result
diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py
new file mode 100644
index 0000000..1f713b5
--- /dev/null
+++ b/tianshou/policy/imitation/discrete_bcq.py
@@ -0,0 +1,129 @@
+import math
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.policy import DQNPolicy
+
+
+class DiscreteBCQPolicy(DQNPolicy):
+ """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> q_value)
+ :param torch.nn.Module imitator: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency.
+ :param float eval_eps: the epsilon-greedy noise added in evaluation.
+ :param float unlikely_action_threshold: the threshold (tau) for unlikely
+ actions, as shown in Equ. (17) in the paper. Default to 0.3.
+ :param float imitation_logits_penalty: regularization weight for imitation
+ logits. Default to 1e-2.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ imitator: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ estimation_step: int = 1,
+ target_update_freq: int = 8000,
+ eval_eps: float = 1e-3,
+ unlikely_action_threshold: float = 0.3,
+ imitation_logits_penalty: float = 1e-2,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
+ assert target_update_freq > 0, "BCQ needs target network setting."
+ self.imitator = imitator
+ assert 0.0 <= unlikely_action_threshold < 1.0, \
+ "unlikely_action_threshold should be in [0, 1)"
+ if unlikely_action_threshold > 0:
+ self._log_tau = math.log(unlikely_action_threshold)
+ else:
+ self._log_tau = -np.inf
+ assert 0.0 <= eval_eps < 1.0
+ self.eps = eval_eps
+ self._weight_reg = imitation_logits_penalty
+
+ def train(self, mode: bool = True) -> "DiscreteBCQPolicy":
+ self.training = mode
+ self.model.train(mode)
+ self.imitator.train(mode)
+ return self
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+ act = self(batch, input="obs_next").act
+ target_q, _ = self.model_old(batch.obs_next)
+ target_q = target_q[np.arange(len(act)), act]
+ return target_q
+
+ def forward( # type: ignore
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ obs = batch[input]
+ q_value, state = self.model(obs, state=state, info=batch.info)
+ if not hasattr(self, "max_action_num"):
+ self.max_action_num = q_value.shape[1]
+ imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)
+
+ # mask actions for argmax
+ ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
+ mask = (ratio < self._log_tau).float()
+ act = (q_value - np.inf * mask).argmax(dim=-1)
+
+ return Batch(
+ act=act, state=state, q_value=q_value, imitation_logits=imitation_logits
+ )
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._iter % self._freq == 0:
+ self.sync_weight()
+ self._iter += 1
+
+ target_q = batch.returns.flatten()
+ result = self(batch)
+ imitation_logits = result.imitation_logits
+ current_q = result.q_value[np.arange(len(target_q)), batch.act]
+ act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
+ q_loss = F.smooth_l1_loss(current_q, target_q)
+ i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act)
+ reg_loss = imitation_logits.pow(2).mean()
+ loss = q_loss + i_loss + self._weight_reg * reg_loss
+
+ self.optim.zero_grad()
+ loss.backward()
+ self.optim.step()
+
+ return {
+ "loss": loss.item(),
+ "loss/q": q_loss.item(),
+ "loss/i": i_loss.item(),
+ "loss/reg": reg_loss.item(),
+ }
diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py
new file mode 100644
index 0000000..217e960
--- /dev/null
+++ b/tianshou/policy/imitation/discrete_cql.py
@@ -0,0 +1,84 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_torch
+from tianshou.policy import QRDQNPolicy
+
+
+class DiscreteCQLPolicy(QRDQNPolicy):
+ """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int num_quantiles: the number of quantile midpoints in the inverse
+ cumulative distribution function of the value. Default to 200.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network).
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param float min_q_weight: the weight for the cql loss.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+ Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ num_quantiles: int = 200,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ min_q_weight: float = 10.0,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, num_quantiles, estimation_step,
+ target_update_freq, reward_normalization, **kwargs
+ )
+ self._min_q_weight = min_q_weight
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ weight = batch.pop("weight", 1.0)
+ all_dist = self(batch).logits
+ act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
+ curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
+ target_dist = batch.returns.unsqueeze(1)
+ # calculate each element's difference between curr_dist and target_dist
+ dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+ huber_loss = (
+ dist_diff *
+ (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
+ qr_loss = (huber_loss * weight).mean()
+ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
+ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
+ batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
+ # add CQL loss
+ q = self.compute_q_value(all_dist, None)
+ dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
+ negative_sampling = q.logsumexp(1).mean()
+ min_q_loss = negative_sampling - dataset_expec
+ loss = qr_loss + min_q_loss * self._min_q_weight
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {
+ "loss": loss.item(),
+ "loss/qr": qr_loss.item(),
+ "loss/cql": min_q_loss.item(),
+ }
diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py
new file mode 100644
index 0000000..edbd25d
--- /dev/null
+++ b/tianshou/policy/imitation/discrete_crr.py
@@ -0,0 +1,126 @@
+from copy import deepcopy
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from torch.distributions import Categorical
+
+from tianshou.data import Batch, to_torch, to_torch_as
+from tianshou.policy.modelfree.pg import PGPolicy
+
+
+class DiscreteCRRPolicy(PGPolicy):
+ r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the action-value critic (i.e., Q function)
+ network. (s -> Q(s, \*))
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param str policy_improvement_mode: type of the weight function f. Possible
+ values: "binary"/"exp"/"all". Default to "exp".
+ :param float ratio_upper_bound: when policy_improvement_mode is "exp", the value
+ of the exp function is upper-bounded by this parameter. Default to 20.
+ :param float beta: when policy_improvement_mode is "exp", this is the denominator
+ of the exp function. Default to 1.
+ :param float min_q_weight: weight for CQL loss/regularizer. Default to 10.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network). Default to 0.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+ Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ policy_improvement_mode: str = "exp",
+ ratio_upper_bound: float = 20.0,
+ beta: float = 1.0,
+ min_q_weight: float = 10.0,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ actor,
+ optim,
+ lambda x: Categorical(logits=x), # type: ignore
+ discount_factor,
+ reward_normalization,
+ **kwargs,
+ )
+ self.critic = critic
+ self._target = target_update_freq > 0
+ self._freq = target_update_freq
+ self._iter = 0
+ if self._target:
+ self.actor_old = deepcopy(self.actor)
+ self.actor_old.eval()
+ self.critic_old = deepcopy(self.critic)
+ self.critic_old.eval()
+ else:
+ self.actor_old = self.actor
+ self.critic_old = self.critic
+ assert policy_improvement_mode in ["exp", "binary", "all"]
+ self._policy_improvement_mode = policy_improvement_mode
+ self._ratio_upper_bound = ratio_upper_bound
+ self._beta = beta
+ self._min_q_weight = min_q_weight
+
+ def sync_weight(self) -> None:
+ self.actor_old.load_state_dict(self.actor.state_dict())
+ self.critic_old.load_state_dict(self.critic.state_dict())
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ q_t = self.critic(batch.obs)
+ act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
+ qa_t = q_t.gather(1, act.unsqueeze(1))
+ # Critic loss
+ with torch.no_grad():
+ target_a_t, _ = self.actor_old(batch.obs_next)
+ target_m = Categorical(logits=target_a_t)
+ q_t_target = self.critic_old(batch.obs_next)
+ rew = to_torch_as(batch.rew, q_t_target)
+ expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
+ expected_target_q[batch.done > 0] = 0.0
+ target = rew.unsqueeze(1) + self._gamma * expected_target_q
+ critic_loss = 0.5 * F.mse_loss(qa_t, target)
+ # Actor loss
+ act_target, _ = self.actor(batch.obs)
+ dist = Categorical(logits=act_target)
+ expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True)
+ advantage = qa_t - expected_policy_q
+ if self._policy_improvement_mode == "binary":
+ actor_loss_coef = (advantage > 0).float()
+ elif self._policy_improvement_mode == "exp":
+ actor_loss_coef = (
+ (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound)
+ )
+ else:
+ actor_loss_coef = 1.0 # effectively behavior cloning
+ actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean()
+ # CQL loss/regularizer
+ min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
+ loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {
+ "loss": loss.item(),
+ "loss/actor": actor_loss.item(),
+ "loss/critic": critic_loss.item(),
+ "loss/cql": min_q_loss.item(),
+ }
diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py
new file mode 100644
index 0000000..b85f52d
--- /dev/null
+++ b/tianshou/policy/imitation/gail.py
@@ -0,0 +1,141 @@
+from typing import Any, Dict, List, Optional, Type
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
+from tianshou.policy import PPOPolicy
+
+
+class GAILPolicy(PPOPolicy):
+ r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param ReplayBuffer expert_buffer: the replay buffer contains expert experience.
+ :param torch.nn.Module disc_net: the discriminator network with input dim equals
+ state dim plus action dim and output dim equals 1.
+ :param torch.optim.Optimizer disc_optim: the optimizer for the discriminator
+ network.
+ :param int disc_update_num: the number of discriminator grad steps per model grad
+ step. Default to 4.
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
+ paper. Default to 0.2.
+ :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+ where c > 1 is a constant indicating the lower bound.
+ Default to 5.0 (set None if you do not want to use it).
+ :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
+ Default to True.
+ :param bool advantage_normalization: whether to do per mini-batch advantage
+ normalization. Default to True.
+ :param bool recompute_advantage: whether to recompute advantage every update
+ repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+ Default to False.
+ :param float vf_coef: weight for value loss. Default to 0.5.
+ :param float ent_coef: weight for entropy loss. Default to 0.01.
+ :param float max_grad_norm: clipping gradients in back propagation. Default to
+ None.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close
+ to 1, also normalize the advantage to Normal(0, 1). Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the model;
+ should be as large as possible within the memory constraint. Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ expert_buffer: ReplayBuffer,
+ disc_net: torch.nn.Module,
+ disc_optim: torch.optim.Optimizer,
+ disc_update_num: int = 4,
+ eps_clip: float = 0.2,
+ dual_clip: Optional[float] = None,
+ value_clip: bool = False,
+ advantage_normalization: bool = True,
+ recompute_advantage: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ actor, critic, optim, dist_fn, eps_clip, dual_clip, value_clip,
+ advantage_normalization, recompute_advantage, **kwargs
+ )
+ self.disc_net = disc_net
+ self.disc_optim = disc_optim
+ self.disc_update_num = disc_update_num
+ self.expert_buffer = expert_buffer
+ self.action_dim = actor.output_dim
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ """Pre-process the data from the provided replay buffer.
+
+ Used in :meth:`update`. Check out :ref:`process_fn` for more information.
+ """
+ # update reward
+ with torch.no_grad():
+ batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten())
+ return super().process_fn(batch, buffer, indices)
+
+ def disc(self, batch: Batch) -> torch.Tensor:
+ obs = to_torch(batch.obs, device=self.disc_net.device)
+ act = to_torch(batch.act, device=self.disc_net.device)
+ return self.disc_net(torch.cat([obs, act], dim=1))
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ # update discriminator
+ losses = []
+ acc_pis = []
+ acc_exps = []
+ bsz = len(batch) // self.disc_update_num
+ for b in batch.split(bsz, merge_last=True):
+ logits_pi = self.disc(b)
+ exp_b = self.expert_buffer.sample(bsz)[0]
+ logits_exp = self.disc(exp_b)
+ loss_pi = -F.logsigmoid(-logits_pi).mean()
+ loss_exp = -F.logsigmoid(logits_exp).mean()
+ loss_disc = loss_pi + loss_exp
+ self.disc_optim.zero_grad()
+ loss_disc.backward()
+ self.disc_optim.step()
+ losses.append(loss_disc.item())
+ acc_pis.append((logits_pi < 0).float().mean().item())
+ acc_exps.append((logits_exp > 0).float().mean().item())
+ # update policy
+ res = super().learn(batch, batch_size, repeat, **kwargs)
+ res["loss/disc"] = losses
+ res["stats/acc_pi"] = acc_pis
+ res["stats/acc_exp"] = acc_exps
+ return res
diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py
new file mode 100644
index 0000000..9746659
--- /dev/null
+++ b/tianshou/policy/imitation/td3_bc.py
@@ -0,0 +1,107 @@
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_torch_as
+from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import TD3Policy
+
+
+class TD3BCPolicy(TD3Policy):
+ """Implementation of TD3+BC. arXiv:2106.06860.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param float exploration_noise: the exploration noise, add to the action.
+ Default to ``GaussianNoise(sigma=0.1)``
+ :param float policy_noise: the noise used in updating policy network.
+ Default to 0.2.
+ :param int update_actor_freq: the update frequency of actor network.
+ Default to 2.
+ :param float noise_clip: the clipping range used in updating policy network.
+ Default to 0.5.
+ :param float alpha: the value of alpha, which controls the weight for TD3 learning
+ relative to behavior cloning.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action) or empty string for no bounding.
+ Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
+ policy_noise: float = 0.2,
+ update_actor_freq: int = 2,
+ noise_clip: float = 0.5,
+ alpha: float = 2.5,
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
+ gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip,
+ reward_normalization, estimation_step, **kwargs
+ )
+ self._alpha = alpha
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim
+ )
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim
+ )
+ batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
+ if self._cnt % self._freq == 0:
+ act = self(batch, eps=0.0).act
+ q_value = self.critic1(batch.obs, act)
+ lmbda = self._alpha / q_value.abs().mean().detach()
+ actor_loss = -lmbda * q_value.mean() + F.mse_loss(
+ act, to_torch_as(batch.act, act)
+ )
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self._last = actor_loss.item()
+ self.actor_optim.step()
+ self.sync_weight()
+ self._cnt += 1
+ return {
+ "loss/actor": self._last,
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ }
diff --git a/tianshou/policy/modelbased/__init__.py b/tianshou/policy/modelbased/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py
new file mode 100644
index 0000000..97fcde2
--- /dev/null
+++ b/tianshou/policy/modelbased/icm.py
@@ -0,0 +1,123 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
+from tianshou.policy import BasePolicy
+from tianshou.utils.net.discrete import IntrinsicCuriosityModule
+
+
+class ICMPolicy(BasePolicy):
+ """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.
+
+ :param BasePolicy policy: a base policy to add ICM to.
+ :param IntrinsicCuriosityModule model: the ICM model.
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float lr_scale: the scaling factor for ICM learning.
+ :param float forward_loss_weight: the weight for forward model loss.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ model: IntrinsicCuriosityModule,
+ optim: torch.optim.Optimizer,
+ lr_scale: float,
+ reward_scale: float,
+ forward_loss_weight: float,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.policy = policy
+ self.model = model
+ self.optim = optim
+ self.lr_scale = lr_scale
+ self.reward_scale = reward_scale
+ self.forward_loss_weight = forward_loss_weight
+
+ def train(self, mode: bool = True) -> "ICMPolicy":
+ """Set the module in training mode."""
+ self.policy.train(mode)
+ self.training = mode
+ self.model.train(mode)
+ return self
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data by inner policy.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ return self.policy.forward(batch, state, **kwargs)
+
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
+ return self.policy.exploration_noise(act, batch)
+
+ def set_eps(self, eps: float) -> None:
+ """Set the eps for epsilon-greedy exploration."""
+ if hasattr(self.policy, "set_eps"):
+ self.policy.set_eps(eps) # type: ignore
+ else:
+ raise NotImplementedError()
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ """Pre-process the data from the provided replay buffer.
+
+ Used in :meth:`update`. Check out :ref:`process_fn` for more information.
+ """
+ mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
+ batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
+ batch.rew += to_numpy(mse_loss * self.reward_scale)
+ return self.policy.process_fn(batch, buffer, indices)
+
+ def post_process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> None:
+ """Post-process the data from the provided replay buffer.
+
+ Typical usage is to update the sampling weight in prioritized
+ experience replay. Used in :meth:`update`.
+ """
+ self.policy.post_process_fn(batch, buffer, indices)
+ batch.rew = batch.policy.orig_rew # restore original reward
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ res = self.policy.learn(batch, **kwargs)
+ self.optim.zero_grad()
+ act_hat = batch.policy.act_hat
+ act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
+ inverse_loss = F.cross_entropy(act_hat, act).mean()
+ forward_loss = batch.policy.mse_loss.mean()
+ loss = (
+ (1 - self.forward_loss_weight) * inverse_loss +
+ self.forward_loss_weight * forward_loss
+ ) * self.lr_scale
+ loss.backward()
+ self.optim.step()
+ res.update(
+ {
+ "loss/icm": loss.item(),
+ "loss/icm/forward": forward_loss.item(),
+ "loss/icm/inverse": inverse_loss.item()
+ }
+ )
+ return res
diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py
new file mode 100644
index 0000000..8e7473a
--- /dev/null
+++ b/tianshou/policy/modelbased/psrl.py
@@ -0,0 +1,221 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class PSRLModel(object):
+ """Implementation of Posterior Sampling Reinforcement Learning Model.
+
+ :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
+ (n_state, n_action, n_state).
+ :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
+ with shape (n_state, n_action).
+ :param np.ndarray rew_std_prior: standard deviations of the normal priors
+ of rewards, with shape (n_state, n_action).
+ :param float discount_factor: in [0, 1].
+ :param float epsilon: for precision control in value iteration.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ """
+
+ def __init__(
+ self,
+ trans_count_prior: np.ndarray,
+ rew_mean_prior: np.ndarray,
+ rew_std_prior: np.ndarray,
+ discount_factor: float,
+ epsilon: float,
+ ) -> None:
+ self.trans_count = trans_count_prior
+ self.n_state, self.n_action = rew_mean_prior.shape
+ self.rew_mean = rew_mean_prior
+ self.rew_std = rew_std_prior
+ self.rew_square_sum = np.zeros_like(rew_mean_prior)
+ self.rew_std_prior = rew_std_prior
+ self.discount_factor = discount_factor
+ self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight
+ self.eps = epsilon
+ self.policy: np.ndarray
+ self.value = np.zeros(self.n_state)
+ self.updated = False
+ self.__eps = np.finfo(np.float32).eps.item()
+
+ def observe(
+ self,
+ trans_count: np.ndarray,
+ rew_sum: np.ndarray,
+ rew_square_sum: np.ndarray,
+ rew_count: np.ndarray,
+ ) -> None:
+ """Add data into memory pool.
+
+ For rewards, we have a normal prior at first. After we observed a
+ reward for a given state-action pair, we use the mean value of our
+ observations instead of the prior mean as the posterior mean. The
+ standard deviations are in inverse proportion to the number of the
+ corresponding observations.
+
+ :param np.ndarray trans_count: the number of observations, with shape
+ (n_state, n_action, n_state).
+ :param np.ndarray rew_sum: total rewards, with shape
+ (n_state, n_action).
+ :param np.ndarray rew_square_sum: total rewards' squares, with shape
+ (n_state, n_action).
+ :param np.ndarray rew_count: the number of rewards, with shape
+ (n_state, n_action).
+ """
+ self.updated = False
+ self.trans_count += trans_count
+ sum_count = self.rew_count + rew_count
+ self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
+ self.rew_square_sum += rew_square_sum
+ raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2
+ self.rew_std = np.sqrt(
+ 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2)
+ )
+ self.rew_count = sum_count
+
+ def sample_trans_prob(self) -> np.ndarray:
+ sample_prob = torch.distributions.Dirichlet(
+ torch.from_numpy(self.trans_count)
+ ).sample().numpy()
+ return sample_prob
+
+ def sample_reward(self) -> np.ndarray:
+ return np.random.normal(self.rew_mean, self.rew_std)
+
+ def solve_policy(self) -> None:
+ self.updated = True
+ self.policy, self.value = self.value_iteration(
+ self.sample_trans_prob(),
+ self.sample_reward(),
+ self.discount_factor,
+ self.eps,
+ self.value,
+ )
+
+ @staticmethod
+ def value_iteration(
+ trans_prob: np.ndarray,
+ rew: np.ndarray,
+ discount_factor: float,
+ eps: float,
+ value: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Value iteration solver for MDPs.
+
+ :param np.ndarray trans_prob: transition probabilities, with shape
+ (n_state, n_action, n_state).
+ :param np.ndarray rew: rewards, with shape (n_state, n_action).
+ :param float eps: for precision control.
+ :param float discount_factor: in [0, 1].
+ :param np.ndarray value: the initialize value of value array, with
+ shape (n_state, ).
+
+ :return: the optimal policy with shape (n_state, ).
+ """
+ Q = rew + discount_factor * trans_prob.dot(value)
+ new_value = Q.max(axis=1)
+ while not np.allclose(new_value, value, eps):
+ value = new_value
+ Q = rew + discount_factor * trans_prob.dot(value)
+ new_value = Q.max(axis=1)
+ # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly
+ Q += eps * np.random.randn(*Q.shape)
+ return Q.argmax(axis=1), new_value
+
+ def __call__(
+ self,
+ obs: np.ndarray,
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> np.ndarray:
+ if not self.updated:
+ self.solve_policy()
+ return self.policy[obs]
+
+
+class PSRLPolicy(BasePolicy):
+ """Implementation of Posterior Sampling Reinforcement Learning.
+
+ Reference: Strens M. A Bayesian framework for reinforcement learning [C]
+ //ICML. 2000, 2000: 943-950.
+
+ :param np.ndarray trans_count_prior: dirichlet prior (alphas), with shape
+ (n_state, n_action, n_state).
+ :param np.ndarray rew_mean_prior: means of the normal priors of rewards,
+ with shape (n_state, n_action).
+ :param np.ndarray rew_std_prior: standard deviations of the normal priors
+ of rewards, with shape (n_state, n_action).
+ :param float discount_factor: in [0, 1].
+ :param float epsilon: for precision control in value iteration.
+ :param bool add_done_loop: whether to add an extra self-loop for the
+ terminal state in MDP. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ trans_count_prior: np.ndarray,
+ rew_mean_prior: np.ndarray,
+ rew_std_prior: np.ndarray,
+ discount_factor: float = 0.99,
+ epsilon: float = 0.01,
+ add_done_loop: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ assert (0.0 <= discount_factor <= 1.0), "discount factor should be in [0, 1]"
+ self.model = PSRLModel(
+ trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon
+ )
+ self._add_done_loop = add_done_loop
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data with PSRL model.
+
+ :return: A :class:`~tianshou.data.Batch` with "act" key containing
+ the action.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ act = self.model(batch.obs, state=state, info=batch.info)
+ return Batch(act=act)
+
+ def learn(self, batch: Batch, *args: Any, **kwargs: Any) -> Dict[str, float]:
+ n_s, n_a = self.model.n_state, self.model.n_action
+ trans_count = np.zeros((n_s, n_a, n_s))
+ rew_sum = np.zeros((n_s, n_a))
+ rew_square_sum = np.zeros((n_s, n_a))
+ rew_count = np.zeros((n_s, n_a))
+ for minibatch in batch.split(size=1):
+ obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
+ trans_count[obs, act, obs_next] += 1
+ rew_sum[obs, act] += minibatch.rew
+ rew_square_sum[obs, act] += minibatch.rew**2
+ rew_count[obs, act] += 1
+ if self._add_done_loop and minibatch.done:
+ # special operation for terminal states: add a self-loop
+ trans_count[obs_next, :, obs_next] += 1
+ rew_count[obs_next, :] += 1
+ self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count)
+ return {
+ "psrl/rew_mean": float(self.model.rew_mean.mean()),
+ "psrl/rew_std": float(self.model.rew_std.mean()),
+ }
diff --git a/tianshou/policy/modelfree/__init__.py b/tianshou/policy/modelfree/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
new file mode 100644
index 0000000..b218cbf
--- /dev/null
+++ b/tianshou/policy/modelfree/a2c.py
@@ -0,0 +1,155 @@
+from typing import Any, Dict, List, Optional, Type
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import PGPolicy
+from tianshou.utils.net.common import ActorCritic
+
+
+class A2CPolicy(PGPolicy):
+ """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param float vf_coef: weight for value loss. Default to 0.5.
+ :param float ent_coef: weight for entropy loss. Default to 0.01.
+ :param float max_grad_norm: clipping gradients in back propagation. Default to
+ None.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close to
+ 1. Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the
+ model; should be as large as possible within the memory constraint.
+ Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ vf_coef: float = 0.5,
+ ent_coef: float = 0.01,
+ max_grad_norm: Optional[float] = None,
+ gae_lambda: float = 0.95,
+ max_batchsize: int = 256,
+ **kwargs: Any
+ ) -> None:
+ super().__init__(actor, optim, dist_fn, **kwargs)
+ self.critic = critic
+ assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
+ self._lambda = gae_lambda
+ self._weight_vf = vf_coef
+ self._weight_ent = ent_coef
+ self._grad_norm = max_grad_norm
+ self._batch = max_batchsize
+ self._actor_critic = ActorCritic(self.actor, self.critic)
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ batch = self._compute_returns(batch, buffer, indices)
+ batch.act = to_torch_as(batch.act, batch.v_s)
+ return batch
+
+ def _compute_returns(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ v_s, v_s_ = [], []
+ with torch.no_grad():
+ for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+ v_s.append(self.critic(minibatch.obs))
+ v_s_.append(self.critic(minibatch.obs_next))
+ batch.v_s = torch.cat(v_s, dim=0).flatten() # old value
+ v_s = batch.v_s.cpu().numpy()
+ v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
+ # when normalizing values, we do not minus self.ret_rms.mean to be numerically
+ # consistent with OPENAI baselines' value normalization pipeline. Emperical
+ # study also shows that "minus mean" will harm performances a tiny little bit
+ # due to unknown reasons (on Mujoco envs, not confident, though).
+ if self._rew_norm: # unnormalize v_s & v_s_
+ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
+ v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
+ unnormalized_returns, advantages = self.compute_episodic_return(
+ batch,
+ buffer,
+ indices,
+ v_s_,
+ v_s,
+ gamma=self._gamma,
+ gae_lambda=self._lambda
+ )
+ if self._rew_norm:
+ batch.returns = unnormalized_returns / \
+ np.sqrt(self.ret_rms.var + self._eps)
+ self.ret_rms.update(unnormalized_returns)
+ else:
+ batch.returns = unnormalized_returns
+ batch.returns = to_torch_as(batch.returns, batch.v_s)
+ batch.adv = to_torch_as(advantages, batch.v_s)
+ return batch
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ losses, actor_losses, vf_losses, ent_losses = [], [], [], []
+ for _ in range(repeat):
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # calculate loss for actor
+ dist = self(minibatch).dist
+ log_prob = dist.log_prob(minibatch.act)
+ log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
+ actor_loss = -(log_prob * minibatch.adv).mean()
+ # calculate loss for critic
+ value = self.critic(minibatch.obs).flatten()
+ vf_loss = F.mse_loss(minibatch.returns, value)
+ # calculate regularization and overall loss
+ ent_loss = dist.entropy().mean()
+ loss = actor_loss + self._weight_vf * vf_loss \
+ - self._weight_ent * ent_loss
+ self.optim.zero_grad()
+ loss.backward()
+ if self._grad_norm: # clip large gradient
+ nn.utils.clip_grad_norm_(
+ self._actor_critic.parameters(), max_norm=self._grad_norm
+ )
+ self.optim.step()
+ actor_losses.append(actor_loss.item())
+ vf_losses.append(vf_loss.item())
+ ent_losses.append(ent_loss.item())
+ losses.append(loss.item())
+
+ return {
+ "loss": losses,
+ "loss/actor": actor_losses,
+ "loss/vf": vf_losses,
+ "loss/ent": ent_losses,
+ }
diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py
new file mode 100644
index 0000000..b02514c
--- /dev/null
+++ b/tianshou/policy/modelfree/bdq.py
@@ -0,0 +1,144 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
+from tianshou.policy import DQNPolicy
+from tianshou.utils.net.common import BranchingNet
+
+
+class BranchingDQNPolicy(DQNPolicy):
+ """Implementation of the Branching dual Q network arXiv:1711.08946.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network). Default to 0.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param bool is_double: use double network. Default to True.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: BranchingNet,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ is_double: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, is_double
+ )
+ assert estimation_step == 1, "N-step bigger than one is not supported by BDQ"
+ self.max_action_num = model.action_per_branch
+ self.num_branches = model.num_branches
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ result = self(batch, input="obs_next")
+ if self._target:
+ # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+ target_q = self(batch, model="model_old", input="obs_next").logits
+ else:
+ target_q = result.logits
+ if self._is_double:
+ act = np.expand_dims(self(batch, input="obs_next").act, -1)
+ act = to_torch(act, dtype=torch.long, device=target_q.device)
+ else:
+ act = target_q.max(-1).indices.unsqueeze(-1)
+ return torch.gather(target_q, -1, act).squeeze()
+
+ def _compute_return(
+ self,
+ batch: Batch,
+ buffer: ReplayBuffer,
+ indice: np.ndarray,
+ gamma: float = 0.99,
+ ) -> Batch:
+ rew = batch.rew
+ with torch.no_grad():
+ target_q_torch = self._target_q(buffer, indice) # (bsz, ?)
+ target_q = to_numpy(target_q_torch)
+ end_flag = buffer.done.copy()
+ end_flag[buffer.unfinished_index()] = True
+ end_flag = end_flag[indice]
+ mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
+ _target_q = rew + gamma * mean_target_q * (1 - end_flag)
+ target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
+ target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
+
+ batch.returns = to_torch_as(target_q, target_q_torch)
+ if hasattr(batch, "weight"): # prio buffer update
+ batch.weight = to_torch_as(batch.weight, target_q_torch)
+ return batch
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ """Compute the 1-step return for BDQ targets."""
+ return self._compute_return(batch, buffer, indices)
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[Dict, Batch, np.ndarray]] = None,
+ model: str = "model",
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ model = getattr(self, model)
+ obs = batch[input]
+ obs_next = obs.obs if hasattr(obs, "obs") else obs
+ logits, hidden = model(obs_next, state=state, info=batch.info)
+ act = to_numpy(logits.max(dim=-1)[1])
+ return Batch(logits=logits, act=act, state=hidden)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ weight = batch.pop("weight", 1.0)
+ act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
+ q = self(batch).logits
+ act_mask = torch.zeros_like(q)
+ act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
+ act_q = q * act_mask
+ returns = batch.returns
+ returns = returns * act_mask
+ td_error = returns - act_q
+ loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
+ batch.weight = td_error.sum(-1).sum(-1) # prio-buffer
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {"loss": loss.item()}
+
+ def exploration_noise(
+ self,
+ act: Union[np.ndarray, Batch],
+ batch: Batch,
+ ) -> Union[np.ndarray, Batch]:
+ if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
+ bsz = len(act)
+ rand_mask = np.random.rand(bsz) < self.eps
+ rand_act = np.random.randint(
+ low=0, high=self.max_action_num, size=(bsz, act.shape[-1])
+ )
+ if hasattr(batch.obs, "mask"):
+ rand_act += batch.obs.mask
+ act[rand_mask] = rand_act[rand_mask]
+ return act
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py
new file mode 100644
index 0000000..3ebc1cf
--- /dev/null
+++ b/tianshou/policy/modelfree/c51.py
@@ -0,0 +1,108 @@
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import DQNPolicy
+
+
+class C51Policy(DQNPolicy):
+ """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int num_atoms: the number of atoms in the support set of the
+ value distribution. Default to 51.
+ :param float v_min: the value of the smallest atom in the support set.
+ Default to -10.0.
+ :param float v_max: the value of the largest atom in the support set.
+ Default to 10.0.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network). Default to 0.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ num_atoms: int = 51,
+ v_min: float = -10.0,
+ v_max: float = 10.0,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
+ assert num_atoms > 1, "num_atoms should be greater than 1"
+ assert v_min < v_max, "v_max should be larger than v_min"
+ self._num_atoms = num_atoms
+ self._v_min = v_min
+ self._v_max = v_max
+ self.support = torch.nn.Parameter(
+ torch.linspace(self._v_min, self._v_max, self._num_atoms),
+ requires_grad=False,
+ )
+ self.delta_z = (v_max - v_min) / (num_atoms - 1)
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ return self.support.repeat(len(indices), 1) # shape: [bsz, num_atoms]
+
+ def compute_q_value(
+ self, logits: torch.Tensor, mask: Optional[np.ndarray]
+ ) -> torch.Tensor:
+ return super().compute_q_value((logits * self.support).sum(2), mask)
+
+ def _target_dist(self, batch: Batch) -> torch.Tensor:
+ if self._target:
+ act = self(batch, input="obs_next").act
+ next_dist = self(batch, model="model_old", input="obs_next").logits
+ else:
+ next_batch = self(batch, input="obs_next")
+ act = next_batch.act
+ next_dist = next_batch.logits
+ next_dist = next_dist[np.arange(len(act)), act, :]
+ target_support = batch.returns.clamp(self._v_min, self._v_max)
+ # An amazing trick for calculating the projection gracefully.
+ # ref: https://github.com/ShangtongZhang/DeepRL
+ target_dist = (
+ 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() /
+ self.delta_z
+ ).clamp(0, 1) * next_dist.unsqueeze(1)
+ return target_dist.sum(-1)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ with torch.no_grad():
+ target_dist = self._target_dist(batch)
+ weight = batch.pop("weight", 1.0)
+ curr_dist = self(batch).logits
+ act = batch.act
+ curr_dist = curr_dist[np.arange(len(act)), act, :]
+ cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
+ loss = (cross_entropy * weight).mean()
+ # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
+ batch.weight = cross_entropy.detach() # prio-buffer
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {"loss": loss.item()}
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
new file mode 100644
index 0000000..0637d05
--- /dev/null
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -0,0 +1,183 @@
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import BasePolicy
+
+
+class DDPGPolicy(BasePolicy):
+ """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param BaseNoise exploration_noise: the exploration noise,
+ add to the action. Default to ``GaussianNoise(sigma=0.1)``.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1),
+ Default to False.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action) or empty string for no bounding.
+ Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: Optional[torch.nn.Module],
+ actor_optim: Optional[torch.optim.Optimizer],
+ critic: Optional[torch.nn.Module],
+ critic_optim: Optional[torch.optim.Optimizer],
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ action_scaling: bool = True,
+ action_bound_method: str = "clip",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ action_scaling=action_scaling,
+ action_bound_method=action_bound_method,
+ **kwargs
+ )
+ assert action_bound_method != "tanh", "tanh mapping is not supported" \
+ "in policies where action is used as input of critic , because" \
+ "raw action in range (-inf, inf) will cause instability in training"
+ if actor is not None and actor_optim is not None:
+ self.actor: torch.nn.Module = actor
+ self.actor_old = deepcopy(actor)
+ self.actor_old.eval()
+ self.actor_optim: torch.optim.Optimizer = actor_optim
+ if critic is not None and critic_optim is not None:
+ self.critic: torch.nn.Module = critic
+ self.critic_old = deepcopy(critic)
+ self.critic_old.eval()
+ self.critic_optim: torch.optim.Optimizer = critic_optim
+ assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
+ self.tau = tau
+ assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
+ self._gamma = gamma
+ self._noise = exploration_noise
+ # it is only a little difference to use GaussianNoise
+ # self.noise = OUNoise()
+ self._rew_norm = reward_normalization
+ self._n_step = estimation_step
+
+ def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
+ """Set the exploration noise."""
+ self._noise = noise
+
+ def train(self, mode: bool = True) -> "DDPGPolicy":
+ """Set the module in training mode, except for the target network."""
+ self.training = mode
+ self.actor.train(mode)
+ self.critic.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ """Soft-update the weight for the target network."""
+ self.soft_update(self.actor_old, self.actor, self.tau)
+ self.soft_update(self.critic_old, self.critic, self.tau)
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ target_q = self.critic_old(
+ batch.obs_next,
+ self(batch, model='actor_old', input='obs_next').act
+ )
+ return target_q
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ batch = self.compute_nstep_return(
+ batch, buffer, indices, self._target_q, self._gamma, self._n_step,
+ self._rew_norm
+ )
+ return batch
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ model: str = "actor",
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which has 2 keys:
+
+ * ``act`` the action.
+ * ``state`` the hidden state.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ model = getattr(self, model)
+ obs = batch[input]
+ actions, hidden = model(obs, state=state, info=batch.info)
+ return Batch(act=actions, state=hidden)
+
+ @staticmethod
+ def _mse_optimizer(
+ batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """A simple wrapper script for updating critic network."""
+ weight = getattr(batch, "weight", 1.0)
+ current_q = critic(batch.obs, batch.act).flatten()
+ target_q = batch.returns.flatten()
+ td = current_q - target_q
+ # critic_loss = F.mse_loss(current_q1, target_q)
+ critic_loss = (td.pow(2) * weight).mean()
+ optimizer.zero_grad()
+ critic_loss.backward()
+ optimizer.step()
+ return td, critic_loss
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic
+ td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
+ batch.weight = td # prio-buffer
+ # actor
+ actor_loss = -self.critic(batch.obs, self(batch).act).mean()
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+ self.sync_weight()
+ return {
+ "loss/actor": actor_loss.item(),
+ "loss/critic": critic_loss.item(),
+ }
+
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
+ if self._noise is None:
+ return act
+ if isinstance(act, np.ndarray):
+ return act + self._noise(act.shape)
+ warnings.warn("Cannot add exploration noise to non-numpy_array action.")
+ return act
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
new file mode 100644
index 0000000..28a2cb2
--- /dev/null
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -0,0 +1,161 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.distributions import Categorical
+
+from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.policy import SACPolicy
+
+
+class DiscreteSACPolicy(SACPolicy):
+ """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s -> Q(s))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s -> Q(s))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+ regularization coefficient. Default to 0.2.
+ If a tuple (target_entropy, log_alpha, alpha_optim) is provided, the
+ alpha is automatically tuned.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau,
+ gamma,
+ alpha,
+ reward_normalization,
+ estimation_step,
+ action_scaling=False,
+ action_bound_method="",
+ **kwargs
+ )
+ self._alpha: Union[float, torch.Tensor]
+
+ def forward( # type: ignore
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ obs = batch[input]
+ logits, hidden = self.actor(obs, state=state, info=batch.info)
+ dist = Categorical(logits=logits)
+ if self._deterministic_eval and not self.training:
+ act = logits.argmax(axis=-1)
+ else:
+ act = dist.sample()
+ return Batch(logits=logits, act=act, state=hidden, dist=dist)
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs: s_{t+n}
+ obs_next_result = self(batch, input="obs_next")
+ dist = obs_next_result.dist
+ target_q = dist.probs * torch.min(
+ self.critic1_old(batch.obs_next),
+ self.critic2_old(batch.obs_next),
+ )
+ target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy()
+ return target_q
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ weight = batch.pop("weight", 1.0)
+ target_q = batch.returns.flatten()
+ act = to_torch(
+ batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long
+ )
+
+ # critic 1
+ current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
+ td1 = current_q1 - target_q
+ critic1_loss = (td1.pow(2) * weight).mean()
+
+ self.critic1_optim.zero_grad()
+ critic1_loss.backward()
+ self.critic1_optim.step()
+
+ # critic 2
+ current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
+ td2 = current_q2 - target_q
+ critic2_loss = (td2.pow(2) * weight).mean()
+
+ self.critic2_optim.zero_grad()
+ critic2_loss.backward()
+ self.critic2_optim.step()
+ batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
+ dist = self(batch).dist
+ entropy = dist.entropy()
+ with torch.no_grad():
+ current_q1a = self.critic1(batch.obs)
+ current_q2a = self.critic2(batch.obs)
+ q = torch.min(current_q1a, current_q2a)
+ actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean()
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+
+ if self._is_auto_alpha:
+ log_prob = -entropy.detach() + self._target_entropy
+ alpha_loss = -(self._log_alpha * log_prob).mean()
+ self._alpha_optim.zero_grad()
+ alpha_loss.backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+
+ self.sync_weight()
+
+ result = {
+ "loss/actor": actor_loss.item(),
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ }
+ if self._is_auto_alpha:
+ result["loss/alpha"] = alpha_loss.item()
+ result["alpha"] = self._alpha.item() # type: ignore
+
+ return result
+
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
+ return act
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
new file mode 100644
index 0000000..7107be9
--- /dev/null
+++ b/tianshou/policy/modelfree/dqn.py
@@ -0,0 +1,204 @@
+from copy import deepcopy
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.policy import BasePolicy
+
+
+class DQNPolicy(BasePolicy):
+ """Implementation of Deep Q Network. arXiv:1312.5602.
+
+ Implementation of Double Q-Learning. arXiv:1509.06461.
+
+ Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
+ implemented in the network side, not here).
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network). Default to 0.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param bool is_double: use double dqn. Default to True.
+ :param bool clip_loss_grad: clip the gradient of the loss in accordance
+ with nature14236; this amounts to using the Huber loss instead of
+ the MSE loss. Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ is_double: bool = True,
+ clip_loss_grad: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.model = model
+ self.optim = optim
+ self.eps = 0.0
+ assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
+ self._gamma = discount_factor
+ assert estimation_step > 0, "estimation_step should be greater than 0"
+ self._n_step = estimation_step
+ self._target = target_update_freq > 0
+ self._freq = target_update_freq
+ self._iter = 0
+ if self._target:
+ self.model_old = deepcopy(self.model)
+ self.model_old.eval()
+ self._rew_norm = reward_normalization
+ self._is_double = is_double
+ self._clip_loss_grad = clip_loss_grad
+
+ def set_eps(self, eps: float) -> None:
+ """Set the eps for epsilon-greedy exploration."""
+ self.eps = eps
+
+ def train(self, mode: bool = True) -> "DQNPolicy":
+ """Set the module in training mode, except for the target network."""
+ self.training = mode
+ self.model.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ """Synchronize the weight for the target network."""
+ self.model_old.load_state_dict(self.model.state_dict())
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ result = self(batch, input="obs_next")
+ if self._target:
+ # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+ target_q = self(batch, model="model_old", input="obs_next").logits
+ else:
+ target_q = result.logits
+ if self._is_double:
+ return target_q[np.arange(len(result.act)), result.act]
+ else: # Nature DQN, over estimate
+ return target_q.max(dim=1)[0]
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ """Compute the n-step return for Q-learning targets.
+
+ More details can be found at
+ :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
+ """
+ batch = self.compute_nstep_return(
+ batch, buffer, indices, self._target_q, self._gamma, self._n_step,
+ self._rew_norm
+ )
+ return batch
+
+ def compute_q_value(
+ self, logits: torch.Tensor, mask: Optional[np.ndarray]
+ ) -> torch.Tensor:
+ """Compute the q value based on the network's raw output and action mask."""
+ if mask is not None:
+ # the masked q value should be smaller than logits.min()
+ min_value = logits.min() - logits.max() - 1.0
+ logits = logits + to_torch_as(1 - mask, logits) * min_value
+ return logits
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ model: str = "model",
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ If you need to mask the action, please add a "mask" into batch.obs, for
+ example, if we have an environment that has "0/1/2" three actions:
+ ::
+
+ batch == Batch(
+ obs=Batch(
+ obs="original obs, with batch_size=1 for demonstration",
+ mask=np.array([[False, True, False]]),
+ # action 1 is available
+ # action 0 and 2 are unavailable
+ ),
+ ...
+ )
+
+ :return: A :class:`~tianshou.data.Batch` which has 3 keys:
+
+ * ``act`` the action.
+ * ``logits`` the network's raw output.
+ * ``state`` the hidden state.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ model = getattr(self, model)
+ obs = batch[input]
+ obs_next = obs.obs if hasattr(obs, "obs") else obs
+ # logits, hidden = model(obs_next, state=state, info=batch.info)
+ logits, hidden = model(obs, state=state, info=batch.info)
+ q = self.compute_q_value(logits, getattr(obs, "mask", None))
+ if not hasattr(self, "max_action_num"):
+ self.max_action_num = q.shape[1]
+ act = to_numpy(q.max(dim=1)[1])
+ return Batch(logits=logits, act=act, state=hidden)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ weight = batch.pop("weight", 1.0)
+ q = self(batch).logits
+ q = q[np.arange(len(q)), batch.act]
+ returns = to_torch_as(batch.returns.flatten(), q)
+ td_error = returns - q
+
+ if self._clip_loss_grad:
+ y = q.reshape(-1, 1)
+ t = returns.reshape(-1, 1)
+ loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
+ else:
+ loss = (td_error.pow(2) * weight).mean()
+
+ batch.weight = td_error # prio-buffer
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {"loss": loss.item()}
+
+ def exploration_noise(
+ self,
+ act: Union[np.ndarray, Batch],
+ batch: Batch,
+ ) -> Union[np.ndarray, Batch]:
+ if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
+ bsz = len(act)
+ rand_mask = np.random.rand(bsz) < self.eps
+ q = np.random.rand(bsz, self.max_action_num) # [0, 1]
+ if hasattr(batch.obs, "mask"):
+ q += batch.obs.mask
+ rand_act = q.argmax(axis=1)
+ act[rand_mask] = rand_act[rand_mask]
+ return act
diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py
new file mode 100644
index 0000000..9eee122
--- /dev/null
+++ b/tianshou/policy/modelfree/fqf.py
@@ -0,0 +1,177 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer, to_numpy
+from tianshou.policy import DQNPolicy, QRDQNPolicy
+from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
+
+
+class FQFPolicy(QRDQNPolicy):
+ """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param FractionProposalNetwork fraction_model: a FractionProposalNetwork for
+ proposing fractions/quantiles given state.
+ :param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing
+ the fraction model above.
+ :param float discount_factor: in [0, 1].
+ :param int num_fractions: the number of fractions to use. Default to 32.
+ :param float ent_coef: the coefficient for entropy loss. Default to 0.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network).
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: FullQuantileFunction,
+ optim: torch.optim.Optimizer,
+ fraction_model: FractionProposalNetwork,
+ fraction_optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ num_fractions: int = 32,
+ ent_coef: float = 0.0,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, num_fractions, estimation_step,
+ target_update_freq, reward_normalization, **kwargs
+ )
+ self.propose_model = fraction_model
+ self._ent_coef = ent_coef
+ self._fraction_optim = fraction_optim
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ if self._target:
+ result = self(batch, input="obs_next")
+ act, fractions = result.act, result.fractions
+ next_dist = self(
+ batch, model="model_old", input="obs_next", fractions=fractions
+ ).logits
+ else:
+ next_batch = self(batch, input="obs_next")
+ act = next_batch.act
+ next_dist = next_batch.logits
+ next_dist = next_dist[np.arange(len(act)), act, :]
+ return next_dist # shape: [bsz, num_quantiles]
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ model: str = "model",
+ input: str = "obs",
+ fractions: Optional[Batch] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ model = getattr(self, model)
+ obs = batch[input]
+ obs_next = obs.obs if hasattr(obs, "obs") else obs
+ if fractions is None:
+ (logits, fractions, quantiles_tau), hidden = model(
+ obs_next,
+ propose_model=self.propose_model,
+ state=state,
+ info=batch.info
+ )
+ else:
+ (logits, _, quantiles_tau), hidden = model(
+ obs_next,
+ propose_model=self.propose_model,
+ fractions=fractions,
+ state=state,
+ info=batch.info
+ )
+ weighted_logits = (fractions.taus[:, 1:] -
+ fractions.taus[:, :-1]).unsqueeze(1) * logits
+ q = DQNPolicy.compute_q_value(
+ self, weighted_logits.sum(2), getattr(obs, "mask", None)
+ )
+ if not hasattr(self, "max_action_num"):
+ self.max_action_num = q.shape[1]
+ act = to_numpy(q.max(dim=1)[1])
+ return Batch(
+ logits=logits,
+ act=act,
+ state=hidden,
+ fractions=fractions,
+ quantiles_tau=quantiles_tau
+ )
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ weight = batch.pop("weight", 1.0)
+ out = self(batch)
+ curr_dist_orig = out.logits
+ taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
+ act = batch.act
+ curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
+ target_dist = batch.returns.unsqueeze(1)
+ # calculate each element's difference between curr_dist and target_dist
+ dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+ huber_loss = (
+ dist_diff * (
+ tau_hats.unsqueeze(2) -
+ (target_dist - curr_dist).detach().le(0.).float()
+ ).abs()
+ ).sum(-1).mean(1)
+ quantile_loss = (huber_loss * weight).mean()
+ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
+ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
+ batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
+ # calculate fraction loss
+ with torch.no_grad():
+ sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
+ sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
+ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
+ # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
+ values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
+ signs_1 = sa_quantiles > torch.cat(
+ [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1
+ )
+
+ values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
+ signs_2 = sa_quantiles < torch.cat(
+ [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1
+ )
+
+ gradient_of_taus = (
+ torch.where(signs_1, values_1, -values_1) +
+ torch.where(signs_2, values_2, -values_2)
+ )
+ fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
+ # calculate entropy loss
+ entropy_loss = out.fractions.entropies.mean()
+ fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
+ self._fraction_optim.zero_grad()
+ fraction_entropy_loss.backward(retain_graph=True)
+ self._fraction_optim.step()
+ self.optim.zero_grad()
+ quantile_loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {
+ "loss": quantile_loss.item() + fraction_entropy_loss.item(),
+ "loss/quantile": quantile_loss.item(),
+ "loss/fraction": fraction_loss.item(),
+ "loss/entropy": entropy_loss.item()
+ }
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
new file mode 100644
index 0000000..74b8d78
--- /dev/null
+++ b/tianshou/policy/modelfree/iqn.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_numpy
+from tianshou.policy import QRDQNPolicy
+
+
+class IQNPolicy(QRDQNPolicy):
+ """Implementation of Implicit Quantile Network. arXiv:1806.06923.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int sample_size: the number of samples for policy evaluation.
+ Default to 32.
+ :param int online_sample_size: the number of samples for online model
+ in training. Default to 8.
+ :param int target_sample_size: the number of samples for target model
+ in training. Default to 8.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network).
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ sample_size: int = 32,
+ online_sample_size: int = 8,
+ target_sample_size: int = 8,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, sample_size, estimation_step,
+ target_update_freq, reward_normalization, **kwargs
+ )
+ assert sample_size > 1, "sample_size should be greater than 1"
+ assert online_sample_size > 1, "online_sample_size should be greater than 1"
+ assert target_sample_size > 1, "target_sample_size should be greater than 1"
+ self._sample_size = sample_size # for policy eval
+ self._online_sample_size = online_sample_size
+ self._target_sample_size = target_sample_size
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ model: str = "model",
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ if model == "model_old":
+ sample_size = self._target_sample_size
+ elif self.training:
+ sample_size = self._online_sample_size
+ else:
+ sample_size = self._sample_size
+ model = getattr(self, model)
+ obs = batch[input]
+ obs_next = obs.obs if hasattr(obs, "obs") else obs
+ (logits, taus), hidden = model(
+ obs_next, sample_size=sample_size, state=state, info=batch.info
+ )
+ q = self.compute_q_value(logits, getattr(obs, "mask", None))
+ if not hasattr(self, "max_action_num"):
+ self.max_action_num = q.shape[1]
+ act = to_numpy(q.max(dim=1)[1])
+ return Batch(logits=logits, act=act, state=hidden, taus=taus)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ weight = batch.pop("weight", 1.0)
+ action_batch = self(batch)
+ curr_dist, taus = action_batch.logits, action_batch.taus
+ act = batch.act
+ curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
+ target_dist = batch.returns.unsqueeze(1)
+ # calculate each element's difference between curr_dist and target_dist
+ dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+ huber_loss = (
+ dist_diff *
+ (taus.unsqueeze(2) -
+ (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
+ loss = (huber_loss * weight).mean()
+ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
+ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
+ batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {"loss": loss.item()}
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
new file mode 100644
index 0000000..e3ab308
--- /dev/null
+++ b/tianshou/policy/modelfree/npg.py
@@ -0,0 +1,184 @@
+from typing import Any, Dict, List, Type
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.distributions import kl_divergence
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import A2CPolicy
+
+
+class NPGPolicy(A2CPolicy):
+ """Implementation of Natural Policy Gradient.
+
+ https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param bool advantage_normalization: whether to do per mini-batch advantage
+ normalization. Default to True.
+ :param int optim_critic_iters: Number of times to optimize critic network per
+ update. Default to 5.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close to
+ 1. Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the
+ model; should be as large as possible within the memory constraint.
+ Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ advantage_normalization: bool = True,
+ optim_critic_iters: int = 5,
+ actor_step_size: float = 0.5,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(actor, critic, optim, dist_fn, **kwargs)
+ del self._weight_vf, self._weight_ent, self._grad_norm
+ self._norm_adv = advantage_normalization
+ self._optim_critic_iters = optim_critic_iters
+ self._step_size = actor_step_size
+ # adjusts Hessian-vector product calculation for numerical stability
+ self._damping = 0.1
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ batch = super().process_fn(batch, buffer, indices)
+ old_log_prob = []
+ with torch.no_grad():
+ for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+ old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
+ batch.logp_old = torch.cat(old_log_prob, dim=0)
+ if self._norm_adv:
+ batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
+ return batch
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ actor_losses, vf_losses, kls = [], [], []
+ for _ in range(repeat):
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # optimize actor
+ # direction: calculate villia gradient
+ dist = self(minibatch).dist
+ log_prob = dist.log_prob(minibatch.act)
+ log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
+ actor_loss = -(log_prob * minibatch.adv).mean()
+ flat_grads = self._get_flat_grad(
+ actor_loss, self.actor, retain_graph=True
+ ).detach()
+
+ # direction: calculate natural gradient
+ with torch.no_grad():
+ old_dist = self(minibatch).dist
+
+ kl = kl_divergence(old_dist, dist).mean()
+ # calculate first order gradient of kl with respect to theta
+ flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+ search_direction = -self._conjugate_gradients(
+ flat_grads, flat_kl_grad, nsteps=10
+ )
+
+ # step
+ with torch.no_grad():
+ flat_params = torch.cat(
+ [param.data.view(-1) for param in self.actor.parameters()]
+ )
+ new_flat_params = flat_params + self._step_size * search_direction
+ self._set_from_flat_params(self.actor, new_flat_params)
+ new_dist = self(minibatch).dist
+ kl = kl_divergence(old_dist, new_dist).mean()
+
+ # optimize citirc
+ for _ in range(self._optim_critic_iters):
+ value = self.critic(minibatch.obs).flatten()
+ vf_loss = F.mse_loss(minibatch.returns, value)
+ self.optim.zero_grad()
+ vf_loss.backward()
+ self.optim.step()
+
+ actor_losses.append(actor_loss.item())
+ vf_losses.append(vf_loss.item())
+ kls.append(kl.item())
+
+ return {
+ "loss/actor": actor_losses,
+ "loss/vf": vf_losses,
+ "kl": kls,
+ }
+
+ def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
+ """Matrix vector product."""
+ # caculate second order gradient of kl with respect to theta
+ kl_v = (flat_kl_grad * v).sum()
+ flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor,
+ retain_graph=True).detach()
+ return flat_kl_grad_grad + v * self._damping
+
+ def _conjugate_gradients(
+ self,
+ minibatch: torch.Tensor,
+ flat_kl_grad: torch.Tensor,
+ nsteps: int = 10,
+ residual_tol: float = 1e-10
+ ) -> torch.Tensor:
+ x = torch.zeros_like(minibatch)
+ r, p = minibatch.clone(), minibatch.clone()
+ # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0.
+ # Change if doing warm start.
+ rdotr = r.dot(r)
+ for _ in range(nsteps):
+ z = self._MVP(p, flat_kl_grad)
+ alpha = rdotr / p.dot(z)
+ x += alpha * p
+ r -= alpha * z
+ new_rdotr = r.dot(r)
+ if new_rdotr < residual_tol:
+ break
+ p = r + new_rdotr / rdotr * p
+ rdotr = new_rdotr
+ return x
+
+ def _get_flat_grad(
+ self, y: torch.Tensor, model: nn.Module, **kwargs: Any
+ ) -> torch.Tensor:
+ grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore
+ return torch.cat([grad.reshape(-1) for grad in grads])
+
+ def _set_from_flat_params(
+ self, model: nn.Module, flat_params: torch.Tensor
+ ) -> nn.Module:
+ prev_ind = 0
+ for param in model.parameters():
+ flat_size = int(np.prod(list(param.size())))
+ param.data.copy_(
+ flat_params[prev_ind:prev_ind + flat_size].view(param.size())
+ )
+ prev_ind += flat_size
+ return model
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
new file mode 100644
index 0000000..4f0b1b1
--- /dev/null
+++ b/tianshou/policy/modelfree/pg.py
@@ -0,0 +1,139 @@
+from typing import Any, Dict, List, Optional, Type, Union
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
+from tianshou.policy import BasePolicy
+from tianshou.utils import RunningMeanStd
+
+
+class PGPolicy(BasePolicy):
+ """Implementation of REINFORCE algorithm.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ discount_factor: float = 0.99,
+ reward_normalization: bool = False,
+ action_scaling: bool = True,
+ action_bound_method: str = "clip",
+ deterministic_eval: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ action_scaling=action_scaling,
+ action_bound_method=action_bound_method,
+ **kwargs
+ )
+ self.actor = model
+ self.optim = optim
+ self.dist_fn = dist_fn
+ assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
+ self._gamma = discount_factor
+ self._rew_norm = reward_normalization
+ self.ret_rms = RunningMeanStd()
+ self._eps = 1e-8
+ self._deterministic_eval = deterministic_eval
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ r"""Compute the discounted returns for each transition.
+
+ .. math::
+ G_t = \sum_{i=t}^T \gamma^{i-t}r_i
+
+ where :math:`T` is the terminal time step, :math:`\gamma` is the
+ discount factor, :math:`\gamma \in [0, 1]`.
+ """
+ v_s_ = np.full(indices.shape, self.ret_rms.mean)
+ unnormalized_returns, _ = self.compute_episodic_return(
+ batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0
+ )
+ if self._rew_norm:
+ batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+ np.sqrt(self.ret_rms.var + self._eps)
+ self.ret_rms.update(unnormalized_returns)
+ else:
+ batch.returns = unnormalized_returns
+ return batch
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute action over the given batch data.
+
+ :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+
+ * ``act`` the action.
+ * ``logits`` the network's raw output.
+ * ``dist`` the action distribution.
+ * ``state`` the hidden state.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
+ if isinstance(logits, tuple):
+ dist = self.dist_fn(*logits)
+ else:
+ dist = self.dist_fn(logits)
+ if self._deterministic_eval and not self.training:
+ if self.action_type == "discrete":
+ act = logits.argmax(-1)
+ elif self.action_type == "continuous":
+ act = logits[0]
+ else:
+ act = dist.sample()
+ return Batch(logits=logits, act=act, state=hidden, dist=dist)
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ losses = []
+ for _ in range(repeat):
+ for minibatch in batch.split(batch_size, merge_last=True):
+ self.optim.zero_grad()
+ result = self(minibatch)
+ dist = result.dist
+ act = to_torch_as(minibatch.act, result.act)
+ ret = to_torch(minibatch.returns, torch.float, result.act.device)
+ log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
+ loss = -(log_prob * ret).mean()
+ loss.backward()
+ self.optim.step()
+ losses.append(loss.item())
+
+ return {"loss": losses}
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
new file mode 100644
index 0000000..b8c1914
--- /dev/null
+++ b/tianshou/policy/modelfree/ppo.py
@@ -0,0 +1,156 @@
+from typing import Any, Dict, List, Optional, Type
+
+import numpy as np
+import torch
+from torch import nn
+
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import A2CPolicy
+from tianshou.utils.net.common import ActorCritic
+
+
+class PPOPolicy(A2CPolicy):
+ r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param float discount_factor: in [0, 1]. Default to 0.99.
+ :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
+ paper. Default to 0.2.
+ :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
+ where c > 1 is a constant indicating the lower bound.
+ Default to 5.0 (set None if you do not want to use it).
+ :param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
+ Default to True.
+ :param bool advantage_normalization: whether to do per mini-batch advantage
+ normalization. Default to True.
+ :param bool recompute_advantage: whether to recompute advantage every update
+ repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+ Default to False.
+ :param float vf_coef: weight for value loss. Default to 0.5.
+ :param float ent_coef: weight for entropy loss. Default to 0.01.
+ :param float max_grad_norm: clipping gradients in back propagation. Default to
+ None.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close
+ to 1, also normalize the advantage to Normal(0, 1). Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the model;
+ should be as large as possible within the memory constraint. Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ eps_clip: float = 0.2,
+ dual_clip: Optional[float] = None,
+ value_clip: bool = False,
+ advantage_normalization: bool = True,
+ recompute_advantage: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(actor, critic, optim, dist_fn, **kwargs)
+ self._eps_clip = eps_clip
+ assert dual_clip is None or dual_clip > 1.0, \
+ "Dual-clip PPO parameter should greater than 1.0."
+ self._dual_clip = dual_clip
+ self._value_clip = value_clip
+ self._norm_adv = advantage_normalization
+ self._recompute_adv = recompute_advantage
+ self._actor_critic: ActorCritic
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+ ) -> Batch:
+ if self._recompute_adv:
+ # buffer input `buffer` and `indices` to be used in `learn()`.
+ self._buffer, self._indices = buffer, indices
+ batch = self._compute_returns(batch, buffer, indices)
+ batch.act = to_torch_as(batch.act, batch.v_s)
+ with torch.no_grad():
+ batch.logp_old = self(batch).dist.log_prob(batch.act)
+ return batch
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ losses, clip_losses, vf_losses, ent_losses = [], [], [], []
+ for step in range(repeat):
+ if self._recompute_adv and step > 0:
+ batch = self._compute_returns(batch, self._buffer, self._indices)
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # calculate loss for actor
+ dist = self(minibatch).dist
+ if self._norm_adv:
+ mean, std = minibatch.adv.mean(), minibatch.adv.std()
+ minibatch.adv = (minibatch.adv -
+ mean) / (std + self._eps) # per-batch norm
+ ratio = (dist.log_prob(minibatch.act) -
+ minibatch.logp_old).exp().float()
+ ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+ surr1 = ratio * minibatch.adv
+ surr2 = ratio.clamp(
+ 1.0 - self._eps_clip, 1.0 + self._eps_clip
+ ) * minibatch.adv
+ if self._dual_clip:
+ clip1 = torch.min(surr1, surr2)
+ clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+ clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
+ else:
+ clip_loss = -torch.min(surr1, surr2).mean()
+ # calculate loss for critic
+ value = self.critic(minibatch.obs).flatten()
+ if self._value_clip:
+ v_clip = minibatch.v_s + \
+ (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+ vf1 = (minibatch.returns - value).pow(2)
+ vf2 = (minibatch.returns - v_clip).pow(2)
+ vf_loss = torch.max(vf1, vf2).mean()
+ else:
+ vf_loss = (minibatch.returns - value).pow(2).mean()
+ # calculate regularization and overall loss
+ ent_loss = dist.entropy().mean()
+ loss = clip_loss + self._weight_vf * vf_loss \
+ - self._weight_ent * ent_loss
+ self.optim.zero_grad()
+ loss.backward()
+ if self._grad_norm: # clip large gradient
+ nn.utils.clip_grad_norm_(
+ self._actor_critic.parameters(), max_norm=self._grad_norm
+ )
+ self.optim.step()
+ clip_losses.append(clip_loss.item())
+ vf_losses.append(vf_loss.item())
+ ent_losses.append(ent_loss.item())
+ losses.append(loss.item())
+
+ return {
+ "loss": losses,
+ "loss/clip": clip_losses,
+ "loss/vf": vf_losses,
+ "loss/ent": ent_losses,
+ }
diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py
new file mode 100644
index 0000000..39dde3d
--- /dev/null
+++ b/tianshou/policy/modelfree/qrdqn.py
@@ -0,0 +1,97 @@
+import warnings
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import DQNPolicy
+
+
+class QRDQNPolicy(DQNPolicy):
+ """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int num_quantiles: the number of quantile midpoints in the inverse
+ cumulative distribution function of the value. Default to 200.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network).
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ discount_factor: float = 0.99,
+ num_quantiles: int = 200,
+ estimation_step: int = 1,
+ target_update_freq: int = 0,
+ reward_normalization: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
+ assert num_quantiles > 1, "num_quantiles should be greater than 1"
+ self._num_quantiles = num_quantiles
+ tau = torch.linspace(0, 1, self._num_quantiles + 1)
+ self.tau_hat = torch.nn.Parameter(
+ ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False
+ )
+ warnings.filterwarnings("ignore", message="Using a target size")
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs_next: s_{t+n}
+ if self._target:
+ act = self(batch, input="obs_next").act
+ next_dist = self(batch, model="model_old", input="obs_next").logits
+ else:
+ next_batch = self(batch, input="obs_next")
+ act = next_batch.act
+ next_dist = next_batch.logits
+ next_dist = next_dist[np.arange(len(act)), act, :]
+ return next_dist # shape: [bsz, num_quantiles]
+
+ def compute_q_value(
+ self, logits: torch.Tensor, mask: Optional[np.ndarray]
+ ) -> torch.Tensor:
+ return super().compute_q_value(logits.mean(2), mask)
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ if self._target and self._iter % self._freq == 0:
+ self.sync_weight()
+ self.optim.zero_grad()
+ weight = batch.pop("weight", 1.0)
+ curr_dist = self(batch).logits
+ act = batch.act
+ curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
+ target_dist = batch.returns.unsqueeze(1)
+ # calculate each element's difference between curr_dist and target_dist
+ dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
+ huber_loss = (
+ dist_diff *
+ (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
+ loss = (huber_loss * weight).mean()
+ # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
+ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
+ batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer
+ loss.backward()
+ self.optim.step()
+ self._iter += 1
+ return {"loss": loss.item()}
diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py
new file mode 100644
index 0000000..773abdd
--- /dev/null
+++ b/tianshou/policy/modelfree/rainbow.py
@@ -0,0 +1,39 @@
+from typing import Any, Dict
+
+from tianshou.data import Batch
+from tianshou.policy import C51Policy
+from tianshou.utils.net.discrete import sample_noise
+
+
+class RainbowPolicy(C51Policy):
+ """Implementation of Rainbow DQN. arXiv:1710.02298.
+
+ :param torch.nn.Module model: a model following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
+ :param float discount_factor: in [0, 1].
+ :param int num_atoms: the number of atoms in the support set of the
+ value distribution. Default to 51.
+ :param float v_min: the value of the smallest atom in the support set.
+ Default to -10.0.
+ :param float v_max: the value of the largest atom in the support set.
+ Default to 10.0.
+ :param int estimation_step: the number of steps to look ahead. Default to 1.
+ :param int target_update_freq: the target network update frequency (0 if
+ you do not use the target network). Default to 0.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.C51Policy` for more detailed
+ explanation.
+ """
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ sample_noise(self.model)
+ if self._target and sample_noise(self.model_old):
+ self.model_old.train() # so that NoisyLinear takes effect
+ return super().learn(batch, **kwargs)
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
new file mode 100644
index 0000000..fdf27d4
--- /dev/null
+++ b/tianshou/policy/modelfree/redq.py
@@ -0,0 +1,200 @@
+from copy import deepcopy
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.distributions import Independent, Normal
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise
+from tianshou.policy import DDPGPolicy
+
+
+class REDQPolicy(DDPGPolicy):
+ """Implementation of REDQ. arXiv:2101.05982.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critics: critic ensemble networks.
+ :param torch.optim.Optimizer critics_optim: the optimizer for the critic networks.
+ :param int ensemble_size: Number of sub-networks in the critic ensemble.
+ Default to 10.
+ :param int subset_size: Number of networks in the subset. Default to 2.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+ regularization coefficient. Default to 0.2.
+ If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+ alpha is automatically tuned.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param int actor_delay: Number of critic updates before an actor update.
+ Default to 20.
+ :param BaseNoise exploration_noise: add a noise to action for exploration.
+ Default to None. This is useful when solving hard-exploration problem.
+ :param bool deterministic_eval: whether to use deterministic action (mean
+ of Gaussian policy) instead of stochastic action sampled by the policy.
+ Default to True.
+ :param str target_mode: methods to integrate critic values in the subset,
+ currently support minimum and average. Default to min.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action) or empty string for no bounding.
+ Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critics: torch.nn.Module,
+ critics_optim: torch.optim.Optimizer,
+ ensemble_size: int = 10,
+ subset_size: int = 2,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ actor_delay: int = 20,
+ exploration_noise: Optional[BaseNoise] = None,
+ deterministic_eval: bool = True,
+ target_mode: str = "min",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ None, None, None, None, tau, gamma, exploration_noise,
+ reward_normalization, estimation_step, **kwargs
+ )
+ self.actor, self.actor_optim = actor, actor_optim
+ self.critics, self.critics_old = critics, deepcopy(critics)
+ self.critics_old.eval()
+ self.critics_optim = critics_optim
+ assert 0 < subset_size <= ensemble_size, \
+ "Invalid choice of ensemble size or subset size."
+ self.ensemble_size = ensemble_size
+ self.subset_size = subset_size
+
+ self._is_auto_alpha = False
+ self._alpha: Union[float, torch.Tensor]
+ if isinstance(alpha, tuple):
+ self._is_auto_alpha = True
+ self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+ assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ self._alpha = alpha
+
+ if target_mode in ("min", "mean"):
+ self.target_mode = target_mode
+ else:
+ raise ValueError("Unsupported mode of Q target computing.")
+
+ self.critic_gradient_step = 0
+ self.actor_delay = actor_delay
+ self._deterministic_eval = deterministic_eval
+ self.__eps = np.finfo(np.float32).eps.item()
+
+ def train(self, mode: bool = True) -> "REDQPolicy":
+ self.training = mode
+ self.actor.train(mode)
+ self.critics.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
+ o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau)
+
+ def forward( # type: ignore
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ obs = batch[input]
+ logits, h = self.actor(obs, state=state, info=batch.info)
+ assert isinstance(logits, tuple)
+ dist = Independent(Normal(*logits), 1)
+ if self._deterministic_eval and not self.training:
+ act = logits[0]
+ else:
+ act = dist.rsample()
+ log_prob = dist.log_prob(act).unsqueeze(-1)
+ # apply correction for Tanh squashing when computing logprob from Gaussian
+ # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+ # in appendix C to get some understanding of this equation.
+ squashed_action = torch.tanh(act)
+ log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
+ self.__eps).sum(-1, keepdim=True)
+ return Batch(
+ logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
+ )
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs: s_{t+n}
+ obs_next_result = self(batch, input="obs_next")
+ a_ = obs_next_result.act
+ sample_ensemble_idx = np.random.choice(
+ self.ensemble_size, self.subset_size, replace=False
+ )
+ qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
+ if self.target_mode == "min":
+ target_q, _ = torch.min(qs, dim=0)
+ elif self.target_mode == "mean":
+ target_q = torch.mean(qs, dim=0)
+ target_q -= self._alpha * obs_next_result.log_prob
+
+ return target_q
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic ensemble
+ weight = getattr(batch, "weight", 1.0)
+ current_qs = self.critics(batch.obs, batch.act).flatten(1)
+ target_q = batch.returns.flatten()
+ td = current_qs - target_q
+ critic_loss = (td.pow(2) * weight).mean()
+ self.critics_optim.zero_grad()
+ critic_loss.backward()
+ self.critics_optim.step()
+ batch.weight = torch.mean(td, dim=0) # prio-buffer
+ self.critic_gradient_step += 1
+
+ # actor
+ if self.critic_gradient_step % self.actor_delay == 0:
+ obs_result = self(batch)
+ a = obs_result.act
+ current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
+ actor_loss = (self._alpha * obs_result.log_prob.flatten() -
+ current_qa).mean()
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+
+ if self._is_auto_alpha:
+ log_prob = obs_result.log_prob.detach() + self._target_entropy
+ alpha_loss = -(self._log_alpha * log_prob).mean()
+ self._alpha_optim.zero_grad()
+ alpha_loss.backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+
+ self.sync_weight()
+
+ result = {"loss/critics": critic_loss.item()}
+ if self.critic_gradient_step % self.actor_delay == 0:
+ result["loss/actor"] = actor_loss.item(),
+ if self._is_auto_alpha:
+ result["loss/alpha"] = alpha_loss.item()
+ result["alpha"] = self._alpha.item() # type: ignore
+
+ return result
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
new file mode 100644
index 0000000..5f17427
--- /dev/null
+++ b/tianshou/policy/modelfree/sac.py
@@ -0,0 +1,190 @@
+from copy import deepcopy
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.distributions import Independent, Normal
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise
+from tianshou.policy import DDPGPolicy
+
+
+class SACPolicy(DDPGPolicy):
+ """Implementation of Soft Actor-Critic. arXiv:1812.05905.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+ regularization coefficient. Default to 0.2.
+ If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+ alpha is automatically tuned.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param BaseNoise exploration_noise: add a noise to action for exploration.
+ Default to None. This is useful when solving hard-exploration problem.
+ :param bool deterministic_eval: whether to use deterministic action (mean
+ of Gaussian policy) instead of stochastic action sampled by the policy.
+ Default to True.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action) or empty string for no bounding.
+ Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ exploration_noise: Optional[BaseNoise] = None,
+ deterministic_eval: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ None, None, None, None, tau, gamma, exploration_noise,
+ reward_normalization, estimation_step, **kwargs
+ )
+ self.actor, self.actor_optim = actor, actor_optim
+ self.critic1, self.critic1_old = critic1, deepcopy(critic1)
+ self.critic1_old.eval()
+ self.critic1_optim = critic1_optim
+ self.critic2, self.critic2_old = critic2, deepcopy(critic2)
+ self.critic2_old.eval()
+ self.critic2_optim = critic2_optim
+
+ self._is_auto_alpha = False
+ self._alpha: Union[float, torch.Tensor]
+ if isinstance(alpha, tuple):
+ self._is_auto_alpha = True
+ self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+ assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ self._alpha = alpha
+
+ self._deterministic_eval = deterministic_eval
+ self.__eps = np.finfo(np.float32).eps.item()
+
+ def train(self, mode: bool = True) -> "SACPolicy":
+ self.training = mode
+ self.actor.train(mode)
+ self.critic1.train(mode)
+ self.critic2.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ self.soft_update(self.critic1_old, self.critic1, self.tau)
+ self.soft_update(self.critic2_old, self.critic2, self.tau)
+
+ def forward( # type: ignore
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ input: str = "obs",
+ **kwargs: Any,
+ ) -> Batch:
+ obs = batch[input]
+ logits, hidden = self.actor(obs, state=state, info=batch.info)
+ assert isinstance(logits, tuple)
+ dist = Independent(Normal(*logits), 1)
+ if self._deterministic_eval and not self.training:
+ act = logits[0]
+ else:
+ act = dist.rsample()
+ log_prob = dist.log_prob(act).unsqueeze(-1)
+ # apply correction for Tanh squashing when computing logprob from Gaussian
+ # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+ # in appendix C to get some understanding of this equation.
+ squashed_action = torch.tanh(act)
+ log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) +
+ self.__eps).sum(-1, keepdim=True)
+ return Batch(
+ logits=logits,
+ act=squashed_action,
+ state=hidden,
+ dist=dist,
+ log_prob=log_prob
+ )
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs: s_{t+n}
+ obs_next_result = self(batch, input="obs_next")
+ act_ = obs_next_result.act
+ target_q = torch.min(
+ self.critic1_old(batch.obs_next, act_),
+ self.critic2_old(batch.obs_next, act_),
+ ) - self._alpha * obs_next_result.log_prob
+ return target_q
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim
+ )
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim
+ )
+ batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
+ obs_result = self(batch)
+ act = obs_result.act
+ current_q1a = self.critic1(batch.obs, act).flatten()
+ current_q2a = self.critic2(batch.obs, act).flatten()
+ actor_loss = (
+ self._alpha * obs_result.log_prob.flatten() -
+ torch.min(current_q1a, current_q2a)
+ ).mean()
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self.actor_optim.step()
+
+ if self._is_auto_alpha:
+ log_prob = obs_result.log_prob.detach() + self._target_entropy
+ # please take a look at issue #258 if you'd like to change this line
+ alpha_loss = -(self._log_alpha * log_prob).mean()
+ self._alpha_optim.zero_grad()
+ alpha_loss.backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+
+ self.sync_weight()
+
+ result = {
+ "loss/actor": actor_loss.item(),
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ }
+ if self._is_auto_alpha:
+ result["loss/alpha"] = alpha_loss.item()
+ result["alpha"] = self._alpha.item() # type: ignore
+
+ return result
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
new file mode 100644
index 0000000..003e099
--- /dev/null
+++ b/tianshou/policy/modelfree/td3.py
@@ -0,0 +1,134 @@
+from copy import deepcopy
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import DDPGPolicy
+
+
+class TD3Policy(DDPGPolicy):
+ """Implementation of TD3, arXiv:1802.09477.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+ :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+ critic network.
+ :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+ :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+ critic network.
+ :param float tau: param for soft update of the target network. Default to 0.005.
+ :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+ :param float exploration_noise: the exploration noise, add to the action.
+ Default to ``GaussianNoise(sigma=0.1)``
+ :param float policy_noise: the noise used in updating policy network.
+ Default to 0.2.
+ :param int update_actor_freq: the update frequency of actor network.
+ Default to 2.
+ :param float noise_clip: the clipping range used in updating policy network.
+ Default to 0.5.
+ :param bool reward_normalization: normalize the reward to Normal(0, 1).
+ Default to False.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action) or empty string for no bounding.
+ Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ actor_optim: torch.optim.Optimizer,
+ critic1: torch.nn.Module,
+ critic1_optim: torch.optim.Optimizer,
+ critic2: torch.nn.Module,
+ critic2_optim: torch.optim.Optimizer,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
+ policy_noise: float = 0.2,
+ update_actor_freq: int = 2,
+ noise_clip: float = 0.5,
+ reward_normalization: bool = False,
+ estimation_step: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ actor, actor_optim, None, None, tau, gamma, exploration_noise,
+ reward_normalization, estimation_step, **kwargs
+ )
+ self.critic1, self.critic1_old = critic1, deepcopy(critic1)
+ self.critic1_old.eval()
+ self.critic1_optim = critic1_optim
+ self.critic2, self.critic2_old = critic2, deepcopy(critic2)
+ self.critic2_old.eval()
+ self.critic2_optim = critic2_optim
+ self._policy_noise = policy_noise
+ self._freq = update_actor_freq
+ self._noise_clip = noise_clip
+ self._cnt = 0
+ self._last = 0
+
+ def train(self, mode: bool = True) -> "TD3Policy":
+ self.training = mode
+ self.actor.train(mode)
+ self.critic1.train(mode)
+ self.critic2.train(mode)
+ return self
+
+ def sync_weight(self) -> None:
+ self.soft_update(self.critic1_old, self.critic1, self.tau)
+ self.soft_update(self.critic2_old, self.critic2, self.tau)
+ self.soft_update(self.actor_old, self.actor, self.tau)
+
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+ batch = buffer[indices] # batch.obs: s_{t+n}
+ act_ = self(batch, model="actor_old", input="obs_next").act
+ noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
+ if self._noise_clip > 0.0:
+ noise = noise.clamp(-self._noise_clip, self._noise_clip)
+ act_ += noise
+ target_q = torch.min(
+ self.critic1_old(batch.obs_next, act_),
+ self.critic2_old(batch.obs_next, act_),
+ )
+ return target_q
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim
+ )
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim
+ )
+ batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
+ if self._cnt % self._freq == 0:
+ actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
+ self.actor_optim.zero_grad()
+ actor_loss.backward()
+ self._last = actor_loss.item()
+ self.actor_optim.step()
+ self.sync_weight()
+ self._cnt += 1
+ return {
+ "loss/actor": self._last,
+ "loss/critic1": critic1_loss.item(),
+ "loss/critic2": critic2_loss.item(),
+ }
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
new file mode 100644
index 0000000..00688af
--- /dev/null
+++ b/tianshou/policy/modelfree/trpo.py
@@ -0,0 +1,154 @@
+import warnings
+from typing import Any, Dict, List, Type
+
+import torch
+import torch.nn.functional as F
+from torch.distributions import kl_divergence
+
+from tianshou.data import Batch
+from tianshou.policy import NPGPolicy
+
+
+class TRPOPolicy(NPGPolicy):
+ """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
+
+ :param torch.nn.Module actor: the actor network following the rules in
+ :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+ :param torch.nn.Module critic: the critic network. (s -> V(s))
+ :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+ :param dist_fn: distribution class for computing the action.
+ :type dist_fn: Type[torch.distributions.Distribution]
+ :param bool advantage_normalization: whether to do per mini-batch advantage
+ normalization. Default to True.
+ :param int optim_critic_iters: Number of times to optimize critic network per
+ update. Default to 5.
+ :param int max_kl: max kl-divergence used to constrain each actor network update.
+ Default to 0.01.
+ :param float backtrack_coeff: Coefficient to be multiplied by step size when
+ constraints are not met. Default to 0.8.
+ :param int max_backtracks: Max number of backtracking times in linesearch. Default
+ to 10.
+ :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+ Default to 0.95.
+ :param bool reward_normalization: normalize estimated values to have std close to
+ 1. Default to False.
+ :param int max_batchsize: the maximum size of the batch when computing GAE,
+ depends on the size of available memory and the memory cost of the
+ model; should be as large as possible within the memory constraint.
+ Default to 256.
+ :param bool action_scaling: whether to map actions from range [-1, 1] to range
+ [action_spaces.low, action_spaces.high]. Default to True.
+ :param str action_bound_method: method to bound action to range [-1, 1], can be
+ either "clip" (for simply clipping the action), "tanh" (for applying tanh
+ squashing) for now, or empty string for no bounding. Default to "clip".
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+ to use option "action_scaling" or "action_bound_method". Default to None.
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+ optimizer in each policy.update(). Default to None (no lr_scheduler).
+ :param bool deterministic_eval: whether to use deterministic action instead of
+ stochastic action sampled by the policy. Default to False.
+ """
+
+ def __init__(
+ self,
+ actor: torch.nn.Module,
+ critic: torch.nn.Module,
+ optim: torch.optim.Optimizer,
+ dist_fn: Type[torch.distributions.Distribution],
+ max_kl: float = 0.01,
+ backtrack_coeff: float = 0.8,
+ max_backtracks: int = 10,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(actor, critic, optim, dist_fn, **kwargs)
+ self._max_backtracks = max_backtracks
+ self._delta = max_kl
+ self._backtrack_coeff = backtrack_coeff
+ self._optim_critic_iters: int
+
+ def learn( # type: ignore
+ self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+ ) -> Dict[str, List[float]]:
+ actor_losses, vf_losses, step_sizes, kls = [], [], [], []
+ for _ in range(repeat):
+ for minibatch in batch.split(batch_size, merge_last=True):
+ # optimize actor
+ # direction: calculate villia gradient
+ dist = self(minibatch).dist # TODO could come from batch
+ ratio = (dist.log_prob(minibatch.act) -
+ minibatch.logp_old).exp().float()
+ ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+ actor_loss = -(ratio * minibatch.adv).mean()
+ flat_grads = self._get_flat_grad(
+ actor_loss, self.actor, retain_graph=True
+ ).detach()
+
+ # direction: calculate natural gradient
+ with torch.no_grad():
+ old_dist = self(minibatch).dist
+
+ kl = kl_divergence(old_dist, dist).mean()
+ # calculate first order gradient of kl with respect to theta
+ flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+ search_direction = -self._conjugate_gradients(
+ flat_grads, flat_kl_grad, nsteps=10
+ )
+
+ # stepsize: calculate max stepsize constrained by kl bound
+ step_size = torch.sqrt(
+ 2 * self._delta /
+ (search_direction *
+ self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True)
+ )
+
+ # stepsize: linesearch stepsize
+ with torch.no_grad():
+ flat_params = torch.cat(
+ [param.data.view(-1) for param in self.actor.parameters()]
+ )
+ for i in range(self._max_backtracks):
+ new_flat_params = flat_params + step_size * search_direction
+ self._set_from_flat_params(self.actor, new_flat_params)
+ # calculate kl and if in bound, loss actually down
+ new_dist = self(minibatch).dist
+ new_dratio = (
+ new_dist.log_prob(minibatch.act) - minibatch.logp_old
+ ).exp().float()
+ new_dratio = new_dratio.reshape(new_dratio.size(0),
+ -1).transpose(0, 1)
+ new_actor_loss = -(new_dratio * minibatch.adv).mean()
+ kl = kl_divergence(old_dist, new_dist).mean()
+
+ if kl < self._delta and new_actor_loss < actor_loss:
+ if i > 0:
+ warnings.warn(f"Backtracking to step {i}.")
+ break
+ elif i < self._max_backtracks - 1:
+ step_size = step_size * self._backtrack_coeff
+ else:
+ self._set_from_flat_params(self.actor, new_flat_params)
+ step_size = torch.tensor([0.0])
+ warnings.warn(
+ "Line search failed! It seems hyperparamters"
+ " are poor and need to be changed."
+ )
+
+ # optimize citirc
+ for _ in range(self._optim_critic_iters):
+ value = self.critic(minibatch.obs).flatten()
+ vf_loss = F.mse_loss(minibatch.returns, value)
+ self.optim.zero_grad()
+ vf_loss.backward()
+ self.optim.step()
+
+ actor_losses.append(actor_loss.item())
+ vf_losses.append(vf_loss.item())
+ step_sizes.append(step_size.item())
+ kls.append(kl.item())
+
+ return {
+ "loss/actor": actor_losses,
+ "loss/vf": vf_losses,
+ "step_size": step_sizes,
+ "kl": kls,
+ }
diff --git a/tianshou/policy/multiagent/__init__.py b/tianshou/policy/multiagent/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
new file mode 100644
index 0000000..bcbba97
--- /dev/null
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -0,0 +1,197 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import BasePolicy
+
+try:
+ from tianshou.env.pettingzoo_env import PettingZooEnv
+except ImportError:
+ PettingZooEnv = None # type: ignore
+
+
+class MultiAgentPolicyManager(BasePolicy):
+ """Multi-agent policy manager for MARL.
+
+ This multi-agent policy manager accepts a list of
+ :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each
+ of these policies when the "forward" is called. The same as "process_fn"
+ and "learn": it splits the data and feeds them to each policy. A figure in
+ :ref:`marl_example` can help you better understand this procedure.
+ """
+
+ def __init__(
+ self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any
+ ) -> None:
+ super().__init__(action_space=env.action_space, **kwargs)
+ assert (
+ len(policies) == len(env.agents)
+ ), "One policy must be assigned for each agent."
+
+ self.agent_idx = env.agent_idx
+ for i, policy in enumerate(policies):
+ # agent_id 0 is reserved for the environment proxy
+ # (this MultiAgentPolicyManager)
+ policy.set_agent_id(env.agents[i])
+
+ self.policies = dict(zip(env.agents, policies))
+
+ def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
+ """Replace the "agent_id"th policy in this manager."""
+ policy.set_agent_id(agent_id)
+ self.policies[agent_id] = policy
+
+ def process_fn(
+ self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+ ) -> Batch:
+ """Dispatch batch data from obs.agent_id to every policy's process_fn.
+
+ Save original multi-dimensional rew in "save_rew", set rew to the
+ reward of each agent during their "process_fn", and restore the
+ original reward afterwards.
+ """
+ results = {}
+ # reward can be empty Batch (after initial reset) or nparray.
+ has_rew = isinstance(buffer.rew, np.ndarray)
+ if has_rew: # save the original reward in save_rew
+ # Since we do not override buffer.__setattr__, here we use _meta to
+ # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
+ save_rew, buffer._meta.rew = buffer.rew, Batch()
+ for agent, policy in self.policies.items():
+ agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
+ if len(agent_index) == 0:
+ results[agent] = Batch()
+ continue
+ tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
+ if has_rew:
+ tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
+ buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
+ if not hasattr(tmp_batch.obs, "mask"):
+ if hasattr(tmp_batch.obs, 'obs'):
+ tmp_batch.obs = tmp_batch.obs.obs
+ if hasattr(tmp_batch.obs_next, 'obs'):
+ tmp_batch.obs_next = tmp_batch.obs_next.obs
+ results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
+ if has_rew: # restore from save_rew
+ buffer._meta.rew = save_rew
+ return Batch(results)
+
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
+ """Add exploration noise from sub-policy onto act."""
+ for agent_id, policy in self.policies.items():
+ agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
+ if len(agent_index) == 0:
+ continue
+ act[agent_index] = policy.exploration_noise(
+ act[agent_index], batch[agent_index]
+ )
+ return act
+
+ def forward( # type: ignore
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Dispatch batch data from obs.agent_id to every policy's forward.
+
+ :param state: if None, it means all agents have no state. If not
+ None, it should contain keys of "agent_1", "agent_2", ...
+
+ :return: a Batch with the following contents:
+
+ ::
+
+ {
+ "act": actions corresponding to the input
+ "state": {
+ "agent_1": output state of agent_1's policy for the state
+ "agent_2": xxx
+ ...
+ "agent_n": xxx}
+ "out": {
+ "agent_1": output of agent_1's policy for the input
+ "agent_2": xxx
+ ...
+ "agent_n": xxx}
+ }
+ """
+ results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
+ Batch]] = []
+ for agent_id, policy in self.policies.items():
+ # This part of code is difficult to understand.
+ # Let's follow an example with two agents
+ # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
+ # each agent plays for three transitions
+ # agent_index for agent 1 is [0, 2, 4]
+ # agent_index for agent 2 is [1, 3, 5]
+ # we separate the transition of each agent according to agent_id
+ agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
+ if len(agent_index) == 0:
+ # (has_data, agent_index, out, act, state)
+ results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
+ continue
+ tmp_batch = batch[agent_index]
+ if isinstance(tmp_batch.rew, np.ndarray):
+ # reward can be empty Batch (after initial reset) or nparray.
+ tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
+ if not hasattr(tmp_batch.obs, "mask"):
+ if hasattr(tmp_batch.obs, 'obs'):
+ tmp_batch.obs = tmp_batch.obs.obs
+ if hasattr(tmp_batch.obs_next, 'obs'):
+ tmp_batch.obs_next = tmp_batch.obs_next.obs
+ out = policy(
+ batch=tmp_batch,
+ state=None if state is None else state[agent_id],
+ **kwargs
+ )
+ act = out.act
+ each_state = out.state \
+ if (hasattr(out, "state") and out.state is not None) \
+ else Batch()
+ results.append((True, agent_index, out, act, each_state))
+ holder = Batch.cat(
+ [
+ {
+ "act": act
+ } for (has_data, agent_index, out, act, each_state) in results
+ if has_data
+ ]
+ )
+ state_dict, out_dict = {}, {}
+ for (agent_id, _), (has_data, agent_index, out, act,
+ state) in zip(self.policies.items(), results):
+ if has_data:
+ holder.act[agent_index] = act
+ state_dict[agent_id] = state
+ out_dict[agent_id] = out
+ holder["out"] = out_dict
+ holder["state"] = state_dict
+ return holder
+
+ def learn(self, batch: Batch,
+ **kwargs: Any) -> Dict[str, Union[float, List[float]]]:
+ """Dispatch the data to all policies for learning.
+
+ :return: a dict with the following contents:
+
+ ::
+
+ {
+ "agent_1/item1": item 1 of agent_1's policy.learn output
+ "agent_1/item2": item 2 of agent_1's policy.learn output
+ "agent_2/xxx": xxx
+ ...
+ "agent_n/xxx": xxx
+ }
+ """
+ results = {}
+ for agent_id, policy in self.policies.items():
+ data = batch[agent_id]
+ if not data.is_empty():
+ out = policy.learn(batch=data, **kwargs)
+ for k, v in out.items():
+ results[agent_id + "/" + k] = v
+ return results
diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py
new file mode 100644
index 0000000..dfb7956
--- /dev/null
+++ b/tianshou/policy/random.py
@@ -0,0 +1,43 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class RandomPolicy(BasePolicy):
+ """A random agent used in multi-agent learning.
+
+ It randomly chooses an action from the legal action.
+ """
+
+ def forward(
+ self,
+ batch: Batch,
+ state: Optional[Union[dict, Batch, np.ndarray]] = None,
+ **kwargs: Any,
+ ) -> Batch:
+ """Compute the random action over the given batch data.
+
+ The input should contain a mask in batch.obs, with "True" to be
+ available and "False" to be unavailable. For example,
+ ``batch.obs.mask == np.array([[False, True, False]])`` means with batch
+ size 1, action "1" is available but action "0" and "2" are unavailable.
+
+ :return: A :class:`~tianshou.data.Batch` with "act" key, containing
+ the random action.
+
+ .. seealso::
+
+ Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+ more detailed explanation.
+ """
+ mask = batch.obs.mask
+ logits = np.random.rand(*mask.shape)
+ logits[~mask] = -np.inf
+ return Batch(act=logits.argmax(axis=-1))
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ """Since a random agent learns nothing, it returns an empty dict."""
+ return {}
diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py
new file mode 100644
index 0000000..8f1361b
--- /dev/null
+++ b/tianshou/trainer/__init__.py
@@ -0,0 +1,34 @@
+"""Trainer package."""
+
+from tianshou.trainer.base import BaseTrainer
+from tianshou.trainer.offline import (
+ OfflineTrainer,
+ offline_trainer,
+ offline_trainer_iter,
+)
+from tianshou.trainer.offpolicy import (
+ OffpolicyTrainer,
+ offpolicy_trainer,
+ offpolicy_trainer_iter,
+)
+from tianshou.trainer.onpolicy import (
+ OnpolicyTrainer,
+ onpolicy_trainer,
+ onpolicy_trainer_iter,
+)
+from tianshou.trainer.utils import gather_info, test_episode
+
+__all__ = [
+ "BaseTrainer",
+ "offpolicy_trainer",
+ "offpolicy_trainer_iter",
+ "OffpolicyTrainer",
+ "onpolicy_trainer",
+ "onpolicy_trainer_iter",
+ "OnpolicyTrainer",
+ "offline_trainer",
+ "offline_trainer_iter",
+ "OfflineTrainer",
+ "test_episode",
+ "gather_info",
+]
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
new file mode 100644
index 0000000..b3f624e
--- /dev/null
+++ b/tianshou/trainer/base.py
@@ -0,0 +1,460 @@
+import time
+from abc import ABC, abstractmethod
+from collections import defaultdict, deque
+from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union
+
+import numpy as np
+import tqdm
+
+from tianshou.data import AsyncCollector, Collector, ReplayBuffer
+from tianshou.policy import BasePolicy
+from tianshou.trainer.utils import gather_info, test_episode
+from tianshou.utils import (
+ BaseLogger,
+ DummyTqdm,
+ LazyLogger,
+ MovAvg,
+ deprecation,
+ tqdm_config,
+)
+
+
+class BaseTrainer(ABC):
+ """An iterator base class for trainers procedure.
+
+ Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
+ on every epoch.
+
+ :param learning_type str: type of learning iterator, available choices are
+ "offpolicy", "onpolicy" and "offline".
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param Collector train_collector: the collector used for training.
+ :param Collector test_collector: the collector used for testing. If it's None,
+ then no testing will be performed.
+ :param int max_epoch: the maximum number of epochs for training. The training
+ process might be finished before reaching ``max_epoch`` if ``stop_fn``
+ is set.
+ :param int step_per_epoch: the number of transitions collected per epoch.
+ :param int repeat_per_collect: the number of repeat time for policy learning,
+ for example, set it to 2 means the policy needs to learn each given batch
+ data twice.
+ :param int episode_per_test: the number of episodes for one policy evaluation.
+ :param int batch_size: the batch size of sample data, which is going to feed in
+ the policy network.
+ :param int step_per_collect: the number of transitions the collector would
+ collect before the network update, i.e., trainer will collect
+ "step_per_collect" transitions and do some policy network update repeatedly
+ in each epoch.
+ :param int episode_per_collect: the number of episodes the collector would
+ collect before the network update, i.e., trainer will collect
+ "episode_per_collect" episodes and do some policy network update repeatedly
+ in each epoch.
+ :param function train_fn: a hook called at the beginning of training in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function test_fn: a hook called at the beginning of testing in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function save_best_fn: a hook called when the undiscounted average mean
+ reward in evaluation phase gets better, with the signature
+ ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
+ :param function save_checkpoint_fn: a function to save training process and
+ return the saved checkpoint path, with the signature ``f(epoch: int,
+ env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+ :param bool resume_from_log: resume env_step/gradient_step and other metadata
+ from existing tensorboard log. Default to False.
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+ bool``, receives the average undiscounted returns of the testing result,
+ returns a boolean which indicates whether reaching the goal.
+ :param function reward_metric: a function with signature
+ ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray
+ with shape (num_episode,)``, used in multi-agent RL. We need to return a
+ single scalar for each episode's result to monitor training in the
+ multi-agent RL setting. This function specifies what is the desired metric,
+ e.g., the reward of agent 1 or the average reward over all agents.
+ :param BaseLogger logger: A logger that logs statistics during
+ training/testing/updating. Default to a logger that doesn't log anything.
+ :param bool verbose: whether to print the information. Default to True.
+ :param bool show_progress: whether to display a progress bar when training.
+ Default to True.
+ :param bool test_in_train: whether to test in the training phase.
+ Default to True.
+ """
+
+ @staticmethod
+ def gen_doc(learning_type: str) -> str:
+ """Document string for subclass trainer."""
+ step_means = f'The "step" in {learning_type} trainer means '
+ if learning_type != "offline":
+ step_means += "an environment step (a.k.a. transition)."
+ else: # offline
+ step_means += "a gradient step."
+
+ trainer_name = learning_type.capitalize() + "Trainer"
+
+ return f"""An iterator class for {learning_type} trainer procedure.
+
+ Returns an iterator that yields a 3-tuple (epoch, stats, info) of
+ train results on every epoch.
+
+ {step_means}
+
+ Example usage:
+
+ ::
+
+ trainer = {trainer_name}(...)
+ for epoch, epoch_stat, info in trainer:
+ print("Epoch:", epoch)
+ print(epoch_stat)
+ print(info)
+ do_something_with_policy()
+ query_something_about_policy()
+ make_a_plot_with(epoch_stat)
+ display(info)
+
+ - epoch int: the epoch number
+ - epoch_stat dict: a large collection of metrics of the current epoch
+ - info dict: result returned from :func:`~tianshou.trainer.gather_info`
+
+ You can even iterate on several trainers at the same time:
+
+ ::
+
+ trainer1 = {trainer_name}(...)
+ trainer2 = {trainer_name}(...)
+ for result1, result2, ... in zip(trainer1, trainer2, ...):
+ compare_results(result1, result2, ...)
+ """
+
+ def __init__(
+ self,
+ learning_type: str,
+ policy: BasePolicy,
+ max_epoch: int,
+ batch_size: int,
+ train_collector: Optional[Collector] = None,
+ test_collector: Optional[Collector] = None,
+ buffer: Optional[ReplayBuffer] = None,
+ step_per_epoch: Optional[int] = None,
+ repeat_per_collect: Optional[int] = None,
+ episode_per_test: Optional[int] = None,
+ update_per_step: Union[int, float] = 1,
+ update_per_epoch: Optional[int] = None,
+ step_per_collect: Optional[int] = None,
+ episode_per_collect: Optional[int] = None,
+ train_fn: Optional[Callable[[int, int], None]] = None,
+ test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
+ stop_fn: Optional[Callable[[float], bool]] = None,
+ save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ resume_from_log: bool = False,
+ reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+ logger: BaseLogger = LazyLogger(),
+ verbose: bool = True,
+ show_progress: bool = True,
+ test_in_train: bool = True,
+ save_fn: Optional[Callable[[BasePolicy], None]] = None,
+ ):
+ if save_fn:
+ deprecation(
+ "save_fn in trainer is marked as deprecated and will be "
+ "removed in the future. Please use save_best_fn instead."
+ )
+ assert save_best_fn is None
+ save_best_fn = save_fn
+
+ self.policy = policy
+ self.buffer = buffer
+
+ self.train_collector = train_collector
+ self.test_collector = test_collector
+
+ self.logger = logger
+ self.start_time = time.time()
+ self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg)
+ self.best_reward = 0.0
+ self.best_reward_std = 0.0
+ self.start_epoch = 0
+ self.gradient_step = 0
+ self.env_step = 0
+ self.max_epoch = max_epoch
+ self.step_per_epoch = step_per_epoch
+
+ # either on of these two
+ self.step_per_collect = step_per_collect
+ self.episode_per_collect = episode_per_collect
+
+ self.update_per_step = update_per_step
+ self.repeat_per_collect = repeat_per_collect
+
+ self.episode_per_test = episode_per_test
+
+ self.batch_size = batch_size
+
+ self.train_fn = train_fn
+ self.test_fn = test_fn
+ self.stop_fn = stop_fn
+ self.save_best_fn = save_best_fn
+ self.save_checkpoint_fn = save_checkpoint_fn
+
+ self.reward_metric = reward_metric
+ self.verbose = verbose
+ self.show_progress = show_progress
+ self.test_in_train = test_in_train
+ self.resume_from_log = resume_from_log
+
+ self.is_run = False
+ self.last_rew, self.last_len = 0.0, 0
+
+ self.epoch = self.start_epoch
+ self.best_epoch = self.start_epoch
+ self.stop_fn_flag = False
+ self.iter_num = 0
+
+ def reset(self) -> None:
+ """Initialize or reset the instance to yield a new iterator from zero."""
+ self.is_run = False
+ self.env_step = 0
+ if self.resume_from_log:
+ self.start_epoch, self.env_step, self.gradient_step = \
+ self.logger.restore_data()
+
+ self.last_rew, self.last_len = 0.0, 0
+ self.start_time = time.time()
+ if self.train_collector is not None:
+ self.train_collector.reset_stat()
+
+ if self.train_collector.policy != self.policy:
+ self.test_in_train = False
+ elif self.test_collector is None:
+ self.test_in_train = False
+
+ if self.test_collector is not None:
+ assert self.episode_per_test is not None
+ assert not isinstance(self.test_collector, AsyncCollector) # Issue 700
+ self.test_collector.reset_stat()
+ test_result = test_episode(
+ self.policy, self.test_collector, self.test_fn, self.start_epoch,
+ self.episode_per_test, self.logger, self.env_step, self.reward_metric
+ )
+ self.best_epoch = self.start_epoch
+ self.best_reward, self.best_reward_std = \
+ test_result["rew"], test_result["rew_std"]
+ if self.save_best_fn:
+ self.save_best_fn(self.policy)
+
+ self.epoch = self.start_epoch
+ self.stop_fn_flag = False
+ self.iter_num = 0
+
+ def __iter__(self): # type: ignore
+ self.reset()
+ return self
+
+ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
+ """Perform one epoch (both train and eval)."""
+ self.epoch += 1
+ self.iter_num += 1
+
+ if self.iter_num > 1:
+
+ # iterator exhaustion check
+ if self.epoch > self.max_epoch:
+ raise StopIteration
+
+ # exit flag 1, when stop_fn succeeds in train_step or test_step
+ if self.stop_fn_flag:
+ raise StopIteration
+
+ # set policy in train mode
+ self.policy.train()
+
+ epoch_stat: Dict[str, Any] = dict()
+
+ if self.show_progress:
+ progress = tqdm.tqdm
+ else:
+ progress = DummyTqdm
+
+ # perform n step_per_epoch
+ with progress(
+ total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
+ ) as t:
+ while t.n < t.total and not self.stop_fn_flag:
+ data: Dict[str, Any] = dict()
+ result: Dict[str, Any] = dict()
+
+ if self.train_collector is not None:
+ data, result, self.stop_fn_flag = self.train_step()
+ t.update(result["n/st"]) # 默认参数n=1,每update一次,进度条+n
+ if self.stop_fn_flag:
+ t.set_postfix(**data)
+ break
+ else:
+ assert self.buffer, "No train_collector or buffer specified"
+ result["n/ep"] = len(self.buffer)
+ result["n/st"] = int(self.gradient_step)
+ t.update()
+
+ self.policy_update_fn(data, result)
+ t.set_postfix(**data) # 设置进度条右边的信息
+
+ if t.n <= t.total and not self.stop_fn_flag:
+ t.update()
+
+ # for offline RL
+ if self.train_collector is None:
+ self.env_step = self.gradient_step * self.batch_size
+
+ if not self.stop_fn_flag:
+ self.logger.save_data(
+ self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
+ )
+ # test
+ if self.test_collector is not None:
+ test_stat, self.stop_fn_flag = self.test_step()
+ if not self.is_run:
+ epoch_stat.update(test_stat)
+
+ if not self.is_run:
+ epoch_stat.update({k: v.get() for k, v in self.stat.items()})
+ epoch_stat["gradient_step"] = self.gradient_step
+ epoch_stat.update(
+ {
+ "env_step": self.env_step,
+ "rew": self.last_rew,
+ "len": int(self.last_len),
+ "n/ep": int(result["n/ep"]),
+ "n/st": int(result["n/st"]),
+ }
+ )
+ info = gather_info(
+ self.start_time, self.train_collector, self.test_collector,
+ self.best_reward, self.best_reward_std
+ )
+ return self.epoch, epoch_stat, info
+ else:
+ return None
+
+ def test_step(self) -> Tuple[Dict[str, Any], bool]:
+ """Perform one testing step."""
+ assert self.episode_per_test is not None
+ assert self.test_collector is not None
+ stop_fn_flag = False
+ test_result = test_episode(
+ self.policy, self.test_collector, self.test_fn, self.epoch,
+ self.episode_per_test, self.logger, self.env_step, self.reward_metric
+ )
+ rew, rew_std = test_result["rew"], test_result["rew_std"]
+ if self.best_epoch < 0 or self.best_reward < rew:
+ self.best_epoch = self.epoch
+ self.best_reward = float(rew)
+ self.best_reward_std = rew_std
+ if self.save_best_fn:
+ self.save_best_fn(self.policy)
+ if self.verbose:
+ print(
+ f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
+ f" best_reward: {self.best_reward:.6f} ± "
+ f"{self.best_reward_std:.6f} in #{self.best_epoch}",
+ flush=True
+ )
+ if not self.is_run:
+ test_stat = {
+ "test_reward": rew,
+ "test_reward_std": rew_std,
+ "best_reward": self.best_reward,
+ "best_reward_std": self.best_reward_std,
+ "best_epoch": self.best_epoch
+ }
+ else:
+ test_stat = {}
+ if self.stop_fn and self.stop_fn(self.best_reward):
+ stop_fn_flag = True
+
+ return test_stat, stop_fn_flag
+
+ def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
+ """Perform one training step."""
+ assert self.episode_per_test is not None
+ assert self.train_collector is not None
+ stop_fn_flag = False
+
+ if self.train_fn:
+ self.train_fn(self.epoch, self.env_step)
+
+ result = self.train_collector.collect(
+ n_step=self.step_per_collect, n_episode=self.episode_per_collect
+ )
+
+ if result["n/ep"] > 0 and self.reward_metric:
+ rew = self.reward_metric(result["rews"])
+ result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
+ self.env_step += int(result["n/st"])
+ self.logger.log_train_data(result, self.env_step)
+ self.logger.write(
+ "train/env_step",
+ self.env_step,
+ {"train/lr": self.policy.optim.param_groups[0]["lr"]}
+ )
+ self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew
+ self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len
+ data = {
+ "env_step": str(self.env_step), # 现在环境训练的步数
+ "rew": f"{self.last_rew:.3f}", # 每个episode的平均reward
+ # "len": str(int(self.last_len)), # 每个episode的平均长度
+ "len": f"{self.last_len:.1f}", # 每个episode的平均长度
+ "n/ep": str(int(result["n/ep"])), # 每个epoch的episode数
+ "n/st": str(int(result["n/st"])), # 每个epoch的step数
+ }
+
+ if result["n/ep"] > 0:
+ if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]):
+ assert self.test_collector is not None
+ test_result = test_episode(
+ self.policy, self.test_collector, self.test_fn, self.epoch,
+ self.episode_per_test, self.logger, self.env_step
+ )
+ if self.stop_fn(test_result["rew"]):
+ stop_fn_flag = True
+ self.best_reward = test_result["rew"]
+ self.best_reward_std = test_result["rew_std"]
+ else:
+ self.policy.train()
+
+ return data, result, stop_fn_flag
+
+ def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None:
+ """Log losses to current logger."""
+ for k in losses.keys():
+ self.stat[k].add(losses[k])
+ losses[k] = self.stat[k].get()
+ data[k] = f"{losses[k]:.3f}"
+ self.logger.log_update_data(losses, self.gradient_step)
+
+ @abstractmethod
+ def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
+ """Policy update function for different trainer implementation.
+
+ :param data: information in progress bar.
+ :param result: collector's return value.
+ """
+
+ def run(self) -> Dict[str, Union[float, str]]:
+ """Consume iterator.
+
+ See itertools - recipes. Use functions that consume iterators at C speed
+ (feed the entire iterator into a zero-length deque).
+ """
+ try:
+ self.is_run = True
+ deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
+ info = gather_info(
+ self.start_time, self.train_collector, self.test_collector,
+ self.best_reward, self.best_reward_std
+ )
+ finally:
+ self.is_run = False
+
+ return info
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
new file mode 100644
index 0000000..33028ca
--- /dev/null
+++ b/tianshou/trainer/offline.py
@@ -0,0 +1,121 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+import numpy as np
+
+from tianshou.data import Collector, ReplayBuffer
+from tianshou.policy import BasePolicy
+from tianshou.trainer.base import BaseTrainer
+from tianshou.utils import BaseLogger, LazyLogger
+
+
+class OfflineTrainer(BaseTrainer):
+ """Create an iterator class for offline training procedure.
+
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+ This buffer must be populated with experiences for offline RL.
+ :param Collector test_collector: the collector used for testing. If it's None,
+ then no testing will be performed.
+ :param int max_epoch: the maximum number of epochs for training. The training
+ process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
+ set.
+ :param int update_per_epoch: the number of policy network updates, so-called
+ gradient steps, per epoch.
+ :param episode_per_test: the number of episodes for one policy evaluation.
+ :param int batch_size: the batch size of sample data, which is going to feed in
+ the policy network.
+ :param function test_fn: a hook called at the beginning of testing in each
+ epoch.
+ It can be used to perform custom additional operations, with the signature
+ ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function save_best_fn: a hook called when the undiscounted average mean
+ reward in evaluation phase gets better, with the signature
+ ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
+ :param function save_checkpoint_fn: a function to save training process and
+ return the saved checkpoint path, with the signature ``f(epoch: int,
+ env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+ Because offline-RL doesn't have env_step, the env_step is always 0 here.
+ :param bool resume_from_log: resume gradient_step and other metadata from
+ existing tensorboard log. Default to False.
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+ bool``, receives the average undiscounted returns of the testing result,
+ returns a boolean which indicates whether reaching the goal.
+ :param function reward_metric: a function with signature ``f(rewards:
+ np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape
+ (num_episode,)``, used in multi-agent RL. We need to return a single scalar
+ for each episode's result to monitor training in the multi-agent RL
+ setting. This function specifies what is the desired metric, e.g., the
+ reward of agent 1 or the average reward over all agents.
+ :param BaseLogger logger: A logger that logs statistics during
+ updating/testing. Default to a logger that doesn't log anything.
+ :param bool verbose: whether to print the information. Default to True.
+ :param bool show_progress: whether to display a progress bar when training.
+ Default to True.
+ """
+
+ __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ buffer: ReplayBuffer,
+ test_collector: Optional[Collector],
+ max_epoch: int,
+ update_per_epoch: int,
+ episode_per_test: int,
+ batch_size: int,
+ test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
+ stop_fn: Optional[Callable[[float], bool]] = None,
+ save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ resume_from_log: bool = False,
+ reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+ logger: BaseLogger = LazyLogger(),
+ verbose: bool = True,
+ show_progress: bool = True,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ learning_type="offline",
+ policy=policy,
+ buffer=buffer,
+ test_collector=test_collector,
+ max_epoch=max_epoch,
+ update_per_epoch=update_per_epoch,
+ step_per_epoch=update_per_epoch,
+ episode_per_test=episode_per_test,
+ batch_size=batch_size,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_best_fn=save_best_fn,
+ save_checkpoint_fn=save_checkpoint_fn,
+ resume_from_log=resume_from_log,
+ reward_metric=reward_metric,
+ logger=logger,
+ verbose=verbose,
+ show_progress=show_progress,
+ **kwargs,
+ )
+
+ def policy_update_fn(
+ self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
+ ) -> None:
+ """Perform one off-line policy update."""
+ assert self.buffer
+ self.gradient_step += 1
+ losses = self.policy.update(self.batch_size, self.buffer)
+ data.update({"gradient_step": str(self.gradient_step)})
+ self.log_update_data(data, losses)
+
+
+def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
+ """Wrapper for offline_trainer run method.
+
+ It is identical to ``OfflineTrainer(...).run()``.
+
+ :return: See :func:`~tianshou.trainer.gather_info`.
+ """
+ return OfflineTrainer(*args, **kwargs).run()
+
+
+offline_trainer_iter = OfflineTrainer
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
new file mode 100644
index 0000000..6282399
--- /dev/null
+++ b/tianshou/trainer/offpolicy.py
@@ -0,0 +1,136 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+import numpy as np
+
+from tianshou.data import Collector
+from tianshou.policy import BasePolicy
+from tianshou.trainer.base import BaseTrainer
+from tianshou.utils import BaseLogger, LazyLogger
+
+
+class OffpolicyTrainer(BaseTrainer):
+ """Create an iterator wrapper for off-policy training procedure.
+
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param Collector train_collector: the collector used for training.
+ :param Collector test_collector: the collector used for testing. If it's None,
+ then no testing will be performed.
+ :param int max_epoch: the maximum number of epochs for training. The training
+ process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
+ set.
+ :param int step_per_epoch: the number of transitions collected per epoch.
+ :param int step_per_collect: the number of transitions the collector would
+ collect before the network update, i.e., trainer will collect
+ "step_per_collect" transitions and do some policy network update repeatedly
+ in each epoch.
+ :param episode_per_test: the number of episodes for one policy evaluation.
+ :param int batch_size: the batch size of sample data, which is going to feed in
+ the policy network.
+ :param int/float update_per_step: the number of times the policy network would
+ be updated per transition after (step_per_collect) transitions are
+ collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256
+ , policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256
+ transitions are collected by the collector. Default to 1.
+ :param function train_fn: a hook called at the beginning of training in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function test_fn: a hook called at the beginning of testing in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function save_best_fn: a hook called when the undiscounted average mean
+ reward in evaluation phase gets better, with the signature
+ ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
+ :param function save_checkpoint_fn: a function to save training process and
+ return the saved checkpoint path, with the signature ``f(epoch: int,
+ env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+ :param bool resume_from_log: resume env_step/gradient_step and other metadata
+ from existing tensorboard log. Default to False.
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+ bool``, receives the average undiscounted returns of the testing result,
+ returns a boolean which indicates whether reaching the goal.
+ :param function reward_metric: a function with signature
+ ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
+ np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
+ return a single scalar for each episode's result to monitor training in the
+ multi-agent RL setting. This function specifies what is the desired metric,
+ e.g., the reward of agent 1 or the average reward over all agents.
+ :param BaseLogger logger: A logger that logs statistics during
+ training/testing/updating. Default to a logger that doesn't log anything.
+ :param bool verbose: whether to print the information. Default to True.
+ :param bool show_progress: whether to display a progress bar when training.
+ Default to True.
+ :param bool test_in_train: whether to test in the training phase.
+ Default to True.
+ """
+
+ __doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:])
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ train_collector: Collector,
+ test_collector: Optional[Collector],
+ max_epoch: int,
+ step_per_epoch: int,
+ step_per_collect: int,
+ episode_per_test: int,
+ batch_size: int,
+ update_per_step: Union[int, float] = 1,
+ train_fn: Optional[Callable[[int, int], None]] = None,
+ test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
+ stop_fn: Optional[Callable[[float], bool]] = None,
+ save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ resume_from_log: bool = False,
+ reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+ logger: BaseLogger = LazyLogger(),
+ verbose: bool = True,
+ show_progress: bool = True,
+ test_in_train: bool = True,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ learning_type="offpolicy",
+ policy=policy,
+ train_collector=train_collector,
+ test_collector=test_collector,
+ max_epoch=max_epoch,
+ step_per_epoch=step_per_epoch,
+ step_per_collect=step_per_collect,
+ episode_per_test=episode_per_test,
+ batch_size=batch_size,
+ update_per_step=update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_best_fn=save_best_fn,
+ save_checkpoint_fn=save_checkpoint_fn,
+ resume_from_log=resume_from_log,
+ reward_metric=reward_metric,
+ logger=logger,
+ verbose=verbose,
+ show_progress=show_progress,
+ test_in_train=test_in_train,
+ **kwargs,
+ )
+
+ def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
+ """Perform off-policy updates."""
+ assert self.train_collector is not None
+ for _ in range(round(self.update_per_step * result["n/st"])):
+ self.gradient_step += 1
+ losses = self.policy.update(self.batch_size, self.train_collector.buffer)
+ self.log_update_data(data, losses)
+
+
+def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
+ """Wrapper for OffPolicyTrainer run method.
+
+ It is identical to ``OffpolicyTrainer(...).run()``.
+
+ :return: See :func:`~tianshou.trainer.gather_info`.
+ """
+ return OffpolicyTrainer(*args, **kwargs).run()
+
+
+offpolicy_trainer_iter = OffpolicyTrainer
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
new file mode 100644
index 0000000..641bda3
--- /dev/null
+++ b/tianshou/trainer/onpolicy.py
@@ -0,0 +1,153 @@
+from typing import Any, Callable, Dict, Optional, Union
+
+import numpy as np
+
+from tianshou.data import Collector
+from tianshou.policy import BasePolicy
+from tianshou.trainer.base import BaseTrainer
+from tianshou.utils import BaseLogger, LazyLogger
+
+
+class OnpolicyTrainer(BaseTrainer):
+ """Create an iterator wrapper for on-policy training procedure.
+
+ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+ :param Collector train_collector: the collector used for training.
+ :param Collector test_collector: the collector used for testing. If it's None,
+ then no testing will be performed.
+ :param int max_epoch: the maximum number of epochs for training. The training
+ process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
+ set.
+ :param int step_per_epoch: the number of transitions collected per epoch.
+ :param int repeat_per_collect: the number of repeat time for policy learning,
+ for example, set it to 2 means the policy needs to learn each given batch
+ data twice.
+ :param int episode_per_test: the number of episodes for one policy evaluation.
+ :param int batch_size: the batch size of sample data, which is going to feed in
+ the policy network.
+ :param int step_per_collect: the number of transitions the collector would
+ collect before the network update, i.e., trainer will collect
+ "step_per_collect" transitions and do some policy network update repeatedly
+ in each epoch.
+ :param int episode_per_collect: the number of episodes the collector would
+ collect before the network update, i.e., trainer will collect
+ "episode_per_collect" episodes and do some policy network update repeatedly
+ in each epoch.
+ :param function train_fn: a hook called at the beginning of training in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function test_fn: a hook called at the beginning of testing in each
+ epoch. It can be used to perform custom additional operations, with the
+ signature ``f(num_epoch: int, step_idx: int) -> None``.
+ :param function save_best_fn: a hook called when the undiscounted average mean
+ reward in evaluation phase gets better, with the signature
+ ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
+ :param function save_checkpoint_fn: a function to save training process and
+ return the saved checkpoint path, with the signature ``f(epoch: int,
+ env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+ :param bool resume_from_log: resume env_step/gradient_step and other metadata
+ from existing tensorboard log. Default to False.
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+ bool``, receives the average undiscounted returns of the testing result,
+ returns a boolean which indicates whether reaching the goal.
+ :param function reward_metric: a function with signature
+ ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
+ np.ndarray with shape (num_episode,)``, used in multi-agent RL.
+ We need to return a single scalar for each episode's result to monitor
+ training in the multi-agent RL setting. This function specifies what is the
+ desired metric, e.g., the reward of agent 1 or the average reward over
+ all agents.
+ :param BaseLogger logger: A logger that logs statistics during
+ training/testing/updating. Default to a logger that doesn't log anything.
+ :param bool verbose: whether to print the information. Default to True.
+ :param bool show_progress: whether to display a progress bar when training.
+ Default to True.
+ :param bool test_in_train: whether to test in the training phase. Default to
+ True.
+
+ .. note::
+
+ Only either one of step_per_collect and episode_per_collect can be specified.
+ """
+
+ __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:])
+
+ def __init__(
+ self,
+ policy: BasePolicy,
+ train_collector: Collector,
+ test_collector: Optional[Collector],
+ max_epoch: int,
+ step_per_epoch: int,
+ repeat_per_collect: int,
+ episode_per_test: int,
+ batch_size: int,
+ step_per_collect: Optional[int] = None,
+ episode_per_collect: Optional[int] = None,
+ train_fn: Optional[Callable[[int, int], None]] = None,
+ test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
+ stop_fn: Optional[Callable[[float], bool]] = None,
+ save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ resume_from_log: bool = False,
+ reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+ logger: BaseLogger = LazyLogger(),
+ verbose: bool = True,
+ show_progress: bool = True,
+ test_in_train: bool = True,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ learning_type="onpolicy",
+ policy=policy,
+ train_collector=train_collector,
+ test_collector=test_collector,
+ max_epoch=max_epoch,
+ step_per_epoch=step_per_epoch,
+ repeat_per_collect=repeat_per_collect,
+ episode_per_test=episode_per_test,
+ batch_size=batch_size,
+ step_per_collect=step_per_collect,
+ episode_per_collect=episode_per_collect,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_best_fn=save_best_fn,
+ save_checkpoint_fn=save_checkpoint_fn,
+ resume_from_log=resume_from_log,
+ reward_metric=reward_metric,
+ logger=logger,
+ verbose=verbose,
+ show_progress=show_progress,
+ test_in_train=test_in_train,
+ **kwargs,
+ )
+
+ def policy_update_fn(
+ self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
+ ) -> None:
+ """Perform one on-policy update."""
+ assert self.train_collector is not None
+ losses = self.policy.update(
+ 0,
+ self.train_collector.buffer,
+ batch_size=self.batch_size,
+ repeat=self.repeat_per_collect,
+ )
+ self.train_collector.reset_buffer(keep_statistics=True)
+ step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
+ self.gradient_step += step
+ self.log_update_data(data, losses)
+
+
+def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
+ """Wrapper for OnpolicyTrainer run method.
+
+ It is identical to ``OnpolicyTrainer(...).run()``.
+
+ :return: See :func:`~tianshou.trainer.gather_info`.
+ """
+ return OnpolicyTrainer(*args, **kwargs).run()
+
+
+onpolicy_trainer_iter = OnpolicyTrainer
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
new file mode 100644
index 0000000..585fdf2
--- /dev/null
+++ b/tianshou/trainer/utils.py
@@ -0,0 +1,98 @@
+import time
+from typing import Any, Callable, Dict, Optional, Union
+
+import numpy as np
+
+from tianshou.data import Collector
+from tianshou.policy import BasePolicy
+from tianshou.utils import BaseLogger
+
+
+def test_episode(
+ policy: BasePolicy,
+ collector: Collector,
+ test_fn: Optional[Callable[[int, Optional[int]], None]],
+ epoch: int,
+ n_episode: int,
+ logger: Optional[BaseLogger] = None,
+ global_step: Optional[int] = None,
+ reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+) -> Dict[str, Any]:
+ """A simple wrapper of testing policy in collector."""
+ collector.reset_env()
+ collector.reset_buffer()
+ policy.eval()
+ if test_fn:
+ test_fn(epoch, global_step)
+ result = collector.collect(n_episode=n_episode)
+ if reward_metric:
+ rew = reward_metric(result["rews"])
+ result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
+ if logger and global_step is not None:
+ logger.log_test_data(result, global_step)
+ return result
+
+
+def gather_info(
+ start_time: float,
+ train_collector: Optional[Collector],
+ test_collector: Optional[Collector],
+ best_reward: float,
+ best_reward_std: float,
+) -> Dict[str, Union[float, str]]:
+ """A simple wrapper of gathering information from collectors.
+
+ :return: A dictionary with the following keys:
+
+ * ``train_step`` the total collected step of training collector;
+ * ``train_episode`` the total collected episode of training collector;
+ * ``train_time/collector`` the time for collecting transitions in the \
+ training collector;
+ * ``train_time/model`` the time for training models;
+ * ``train_speed`` the speed of training (env_step per second);
+ * ``test_step`` the total collected step of test collector;
+ * ``test_episode`` the total collected episode of test collector;
+ * ``test_time`` the time for testing;
+ * ``test_speed`` the speed of testing (env_step per second);
+ * ``best_reward`` the best reward over the test results;
+ * ``duration`` the total elapsed time.
+ """
+ duration = max(0, time.time() - start_time)
+ model_time = duration
+ result: Dict[str, Union[float, str]] = {
+ "duration": f"{duration:.2f}s",
+ "train_time/model": f"{model_time:.2f}s",
+ }
+ if test_collector is not None:
+ model_time = max(0, duration - test_collector.collect_time)
+ test_speed = test_collector.collect_step / test_collector.collect_time
+ result.update(
+ {
+ "test_step": test_collector.collect_step,
+ "test_episode": test_collector.collect_episode,
+ "test_time": f"{test_collector.collect_time:.2f}s",
+ "test_speed": f"{test_speed:.2f} step/s",
+ "best_reward": best_reward,
+ "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
+ "duration": f"{duration:.2f}s",
+ "train_time/model": f"{model_time:.2f}s",
+ }
+ )
+ if train_collector is not None:
+ model_time = max(0, model_time - train_collector.collect_time)
+ if test_collector is not None:
+ train_speed = train_collector.collect_step / (
+ duration - test_collector.collect_time
+ )
+ else:
+ train_speed = train_collector.collect_step / duration
+ result.update(
+ {
+ "train_step": train_collector.collect_step,
+ "train_episode": train_collector.collect_episode,
+ "train_time/collector": f"{train_collector.collect_time:.2f}s",
+ "train_time/model": f"{model_time:.2f}s",
+ "train_speed": f"{train_speed:.2f} step/s",
+ }
+ )
+ return result
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
new file mode 100644
index 0000000..78eb29d
--- /dev/null
+++ b/tianshou/utils/__init__.py
@@ -0,0 +1,23 @@
+"""Utils package."""
+
+from tianshou.utils.logger.base import BaseLogger, LazyLogger
+from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
+from tianshou.utils.logger.wandb import WandbLogger
+from tianshou.utils.lr_scheduler import MultipleLRSchedulers
+from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
+from tianshou.utils.statistics import MovAvg, RunningMeanStd
+from tianshou.utils.warning import deprecation
+
+__all__ = [
+ "MovAvg",
+ "RunningMeanStd",
+ "tqdm_config",
+ "DummyTqdm",
+ "BaseLogger",
+ "TensorboardLogger",
+ "BasicLogger",
+ "LazyLogger",
+ "WandbLogger",
+ "deprecation",
+ "MultipleLRSchedulers",
+]
diff --git a/tianshou/utils/logger/__init__.py b/tianshou/utils/logger/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
new file mode 100644
index 0000000..d29792a
--- /dev/null
+++ b/tianshou/utils/logger/base.py
@@ -0,0 +1,144 @@
+from abc import ABC, abstractmethod
+from numbers import Number
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import numpy as np
+
+LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
+
+
+class BaseLogger(ABC):
+ """The base class for any logger which is compatible with trainer.
+
+ Try to overwrite write() method to use your own writer.
+
+ :param int train_interval: the log interval in log_train_data(). Default to 1000.
+ :param int test_interval: the log interval in log_test_data(). Default to 1.
+ :param int update_interval: the log interval in log_update_data(). Default to 1000.
+ """
+
+ def __init__(
+ self,
+ train_interval: int = 1000,
+ test_interval: int = 1,
+ update_interval: int = 1000,
+ ) -> None:
+ super().__init__()
+ self.train_interval = train_interval
+ self.test_interval = test_interval
+ self.update_interval = update_interval
+ self.last_log_train_step = -1
+ self.last_log_test_step = -1
+ self.last_log_update_step = -1
+
+ @abstractmethod
+ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+ """Specify how the writer is used to log data.
+
+ :param str step_type: namespace which the data dict belongs to.
+ :param int step: stands for the ordinate of the data dict.
+ :param dict data: the data to write with format ``{key: value}``.
+ """
+ pass
+
+ def log_train_data(self, collect_result: dict, step: int) -> None:
+ """Use writer to log statistics generated during training.
+
+ :param collect_result: a dict containing information of data collected in
+ training stage, i.e., returns of collector.collect().
+ :param int step: stands for the timestep the collect_result being logged.
+ """
+ if collect_result["n/ep"] > 0:
+ if step - self.last_log_train_step >= self.train_interval:
+ log_data = {
+ "train/episode": collect_result["n/ep"],
+ "train/reward": collect_result["rew"],
+ "train/num": collect_result["num"],
+ "train/ratio": collect_result["ratio"],
+ }
+ self.write("train/env_step", step, log_data)
+ self.last_log_train_step = step
+
+ def log_test_data(self, collect_result: dict, step: int) -> None:
+ """Use writer to log statistics generated during evaluating.
+
+ :param collect_result: a dict containing information of data collected in
+ evaluating stage, i.e., returns of collector.collect().
+ :param int step: stands for the timestep the collect_result being logged.
+ """
+ assert collect_result["n/ep"] > 0
+ if step - self.last_log_test_step >= self.test_interval:
+ log_data = {
+ # "test/env_step": step,
+ "test/reward": collect_result["rew"],
+ "test/num": collect_result["num"],
+ "test/reward_std": collect_result["rew_std"],
+ "test/num_std": collect_result["num_std"],
+ "test/ratio": collect_result["ratio"],
+ }
+ self.write("test/env_step", step, log_data)
+ self.last_log_test_step = step
+
+ def log_update_data(self, update_result: dict, step: int) -> None:
+ """Use writer to log statistics generated during updating.
+
+ :param update_result: a dict containing information of data collected in
+ updating stage, i.e., returns of policy.update().
+ :param int step: stands for the timestep the collect_result being logged.
+ """
+ if step - self.last_log_update_step >= self.update_interval:
+ log_data = {f"update/{k}": v for k, v in update_result.items()}
+ self.write("update/gradient_step", step, log_data)
+ self.last_log_update_step = step
+
+ @abstractmethod
+ def save_data(
+ self,
+ epoch: int,
+ env_step: int,
+ gradient_step: int,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ ) -> None:
+ """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
+
+ :param int epoch: the epoch in trainer.
+ :param int env_step: the env_step in trainer.
+ :param int gradient_step: the gradient_step in trainer.
+ :param function save_checkpoint_fn: a hook defined by user, see trainer
+ documentation for detail.
+ """
+ pass
+
+ @abstractmethod
+ def restore_data(self) -> Tuple[int, int, int]:
+ """Return the metadata from existing log.
+
+ If it finds nothing or an error occurs during the recover process, it will
+ return the default parameters.
+
+ :return: epoch, env_step, gradient_step.
+ """
+ pass
+
+
+class LazyLogger(BaseLogger):
+ """A logger that does nothing. Used as the placeholder in trainer."""
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+ """The LazyLogger writes nothing."""
+ pass
+
+ def save_data(
+ self,
+ epoch: int,
+ env_step: int,
+ gradient_step: int,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ ) -> None:
+ pass
+
+ def restore_data(self) -> Tuple[int, int, int]:
+ return 0, 0, 0
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
new file mode 100644
index 0000000..2583673
--- /dev/null
+++ b/tianshou/utils/logger/tensorboard.py
@@ -0,0 +1,93 @@
+from typing import Any, Callable, Optional, Tuple
+
+from tensorboard.backend.event_processing import event_accumulator
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger
+from tianshou.utils.warning import deprecation
+
+
+class TensorboardLogger(BaseLogger):
+ """A logger that relies on tensorboard SummaryWriter by default to visualize \
+ and log statistics.
+
+ :param SummaryWriter writer: the writer to log data.
+ :param int train_interval: the log interval in log_train_data(). Default to 1000.
+ :param int test_interval: the log interval in log_test_data(). Default to 1.
+ :param int update_interval: the log interval in log_update_data(). Default to 1000.
+ :param int save_interval: the save interval in save_data(). Default to 1 (save at
+ the end of each epoch).
+ :param bool write_flush: whether to flush tensorboard result after each
+ add_scalar operation. Default to True.
+ """
+
+ def __init__(
+ self,
+ writer: SummaryWriter,
+ train_interval: int = 1000,
+ test_interval: int = 1,
+ update_interval: int = 1000,
+ save_interval: int = 1,
+ write_flush: bool = True,
+ ) -> None:
+ super().__init__(train_interval, test_interval, update_interval)
+ self.save_interval = save_interval
+ self.write_flush = write_flush
+ self.last_save_step = -1
+ self.writer = writer
+
+ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+ for k, v in data.items():
+ self.writer.add_scalar(k, v, global_step=step)
+ if self.write_flush: # issue 580
+ self.writer.flush() # issue #482
+
+ def save_data(
+ self,
+ epoch: int,
+ env_step: int,
+ gradient_step: int,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ ) -> None:
+ if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
+ self.last_save_step = epoch
+ save_checkpoint_fn(epoch, env_step, gradient_step)
+ self.write("save/epoch", epoch, {"save/epoch": epoch})
+ self.write("save/env_step", env_step, {"save/env_step": env_step})
+ self.write(
+ "save/gradient_step", gradient_step,
+ {"save/gradient_step": gradient_step}
+ )
+
+ def restore_data(self) -> Tuple[int, int, int]:
+ ea = event_accumulator.EventAccumulator(self.writer.log_dir)
+ ea.Reload()
+
+ try: # epoch / gradient_step
+ epoch = ea.scalars.Items("save/epoch")[-1].step
+ self.last_save_step = self.last_log_test_step = epoch
+ gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
+ self.last_log_update_step = gradient_step
+ except KeyError:
+ epoch, gradient_step = 0, 0
+ try: # offline trainer doesn't have env_step
+ env_step = ea.scalars.Items("save/env_step")[-1].step
+ self.last_log_train_step = env_step
+ except KeyError:
+ env_step = 0
+
+ return epoch, env_step, gradient_step
+
+
+class BasicLogger(TensorboardLogger):
+ """BasicLogger has changed its name to TensorboardLogger in #427.
+
+ This class is for compatibility.
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ deprecation(
+ "Class BasicLogger is marked as deprecated and will be removed soon. "
+ "Please use TensorboardLogger instead."
+ )
+ super().__init__(*args, **kwargs)
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
new file mode 100644
index 0000000..c8e78d5
--- /dev/null
+++ b/tianshou/utils/logger/wandb.py
@@ -0,0 +1,149 @@
+import argparse
+import os
+from typing import Callable, Optional, Tuple
+
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils import BaseLogger, TensorboardLogger
+from tianshou.utils.logger.base import LOG_DATA_TYPE
+
+try:
+ import wandb
+except ImportError:
+ pass
+
+
+class WandbLogger(BaseLogger):
+ """Weights and Biases logger that sends data to https://wandb.ai/.
+
+ This logger creates three panels with plots: train, test, and update.
+ Make sure to select the correct access for each panel in weights and biases:
+
+ Example of usage:
+ ::
+
+ logger = WandbLogger()
+ logger.load(SummaryWriter(log_path))
+ result = onpolicy_trainer(policy, train_collector, test_collector,
+ logger=logger)
+
+ :param int train_interval: the log interval in log_train_data(). Default to 1000.
+ :param int test_interval: the log interval in log_test_data(). Default to 1.
+ :param int update_interval: the log interval in log_update_data().
+ Default to 1000.
+ :param int save_interval: the save interval in save_data(). Default to 1 (save at
+ the end of each epoch).
+ :param bool write_flush: whether to flush tensorboard result after each
+ add_scalar operation. Default to True.
+ :param str project: W&B project name. Default to "tianshou".
+ :param str name: W&B run name. Default to None. If None, random name is assigned.
+ :param str entity: W&B team/organization name. Default to None.
+ :param str run_id: run id of W&B run to be resumed. Default to None.
+ :param argparse.Namespace config: experiment configurations. Default to None.
+ """
+
+ def __init__(
+ self,
+ train_interval: int = 1000,
+ test_interval: int = 1,
+ update_interval: int = 1000,
+ save_interval: int = 1000,
+ write_flush: bool = True,
+ project: Optional[str] = None,
+ name: Optional[str] = None,
+ entity: Optional[str] = None,
+ run_id: Optional[str] = None,
+ config: Optional[argparse.Namespace] = None,
+ ) -> None:
+ super().__init__(train_interval, test_interval, update_interval)
+ self.last_save_step = -1
+ self.save_interval = save_interval
+ self.write_flush = write_flush
+ self.restored = False
+ if project is None:
+ project = os.getenv("WANDB_PROJECT", "tianshou")
+
+ self.wandb_run = wandb.init(
+ project=project,
+ name=name,
+ id=run_id,
+ resume="allow",
+ entity=entity,
+ sync_tensorboard=True,
+ monitor_gym=True,
+ config=config, # type: ignore
+ ) if not wandb.run else wandb.run
+ self.wandb_run._label(repo="tianshou") # type: ignore
+ self.tensorboard_logger: Optional[TensorboardLogger] = None
+
+ def load(self, writer: SummaryWriter) -> None:
+ self.writer = writer
+ self.tensorboard_logger = TensorboardLogger(
+ writer, self.train_interval, self.test_interval, self.update_interval,
+ self.save_interval, self.write_flush
+ )
+
+ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+ if self.tensorboard_logger is None:
+ raise Exception(
+ "`logger` needs to load the Tensorboard Writer before "
+ "writing data. Try `logger.load(SummaryWriter(log_path))`"
+ )
+ else:
+ self.tensorboard_logger.write(step_type, step, data)
+
+ def save_data(
+ self,
+ epoch: int,
+ env_step: int,
+ gradient_step: int,
+ save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
+ ) -> None:
+ """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
+
+ :param int epoch: the epoch in trainer.
+ :param int env_step: the env_step in trainer.
+ :param int gradient_step: the gradient_step in trainer.
+ :param function save_checkpoint_fn: a hook defined by user, see trainer
+ documentation for detail.
+ """
+ if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
+ self.last_save_step = epoch
+ checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
+
+ checkpoint_artifact = wandb.Artifact(
+ 'run_' + self.wandb_run.id + '_checkpoint', # type: ignore
+ type='model',
+ metadata={
+ "save/epoch": epoch,
+ "save/env_step": env_step,
+ "save/gradient_step": gradient_step,
+ "checkpoint_path": str(checkpoint_path),
+ }
+ )
+ checkpoint_artifact.add_file(str(checkpoint_path))
+ self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore
+
+ def restore_data(self) -> Tuple[int, int, int]:
+ checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
+ f"run_{self.wandb_run.id}_checkpoint:latest" # type: ignore
+ )
+ assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
+
+ checkpoint_artifact.download(
+ os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
+ )
+
+ try: # epoch / gradient_step
+ epoch = checkpoint_artifact.metadata["save/epoch"]
+ self.last_save_step = self.last_log_test_step = epoch
+ gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
+ self.last_log_update_step = gradient_step
+ except KeyError:
+ epoch, gradient_step = 0, 0
+ try: # offline trainer doesn't have env_step
+ env_step = checkpoint_artifact.metadata["save/env_step"]
+ self.last_log_train_step = env_step
+ except KeyError:
+ env_step = 0
+ return epoch, env_step, gradient_step
diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py
new file mode 100644
index 0000000..36a08b2
--- /dev/null
+++ b/tianshou/utils/lr_scheduler.py
@@ -0,0 +1,42 @@
+from typing import Dict, List
+
+import torch
+
+
+class MultipleLRSchedulers:
+ """A wrapper for multiple learning rate schedulers.
+
+ Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called,
+ it calls the step() method of each of the schedulers that it contains.
+ Example usage:
+ ::
+
+ scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2)
+ scheduler2 = ExponentialLR(opt2, gamma=0.9)
+ scheduler = MultipleLRSchedulers(scheduler1, scheduler2)
+ policy = PPOPolicy(..., lr_scheduler=scheduler)
+ """
+
+ def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+ self.schedulers = args
+
+ def step(self) -> None:
+ """Take a step in each of the learning rate schedulers."""
+ for scheduler in self.schedulers:
+ scheduler.step()
+
+ def state_dict(self) -> List[Dict]:
+ """Get state_dict for each of the learning rate schedulers.
+
+ :return: A list of state_dict of learning rate schedulers.
+ """
+ return [s.state_dict() for s in self.schedulers]
+
+ def load_state_dict(self, state_dict: List[Dict]) -> None:
+ """Load states from state_dict.
+
+ :param List[Dict] state_dict: A list of learning rate scheduler
+ state_dict, in the same order as the schedulers.
+ """
+ for (s, sd) in zip(self.schedulers, state_dict):
+ s.__dict__.update(sd)
diff --git a/tianshou/utils/net/__init__.py b/tianshou/utils/net/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
new file mode 100644
index 0000000..1ce3946
--- /dev/null
+++ b/tianshou/utils/net/common.py
@@ -0,0 +1,553 @@
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ no_type_check,
+)
+
+import numpy as np
+import torch
+from torch import nn
+
+from tianshou.data.batch import Batch
+
+ModuleType = Type[nn.Module]
+ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
+ Sequence[Dict[Any, Any]]]
+
+
+def miniblock(
+ input_size: int,
+ output_size: int = 0,
+ norm_layer: Optional[ModuleType] = None,
+ norm_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
+ activation: Optional[ModuleType] = None,
+ act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
+ linear_layer: Type[nn.Linear] = nn.Linear,
+) -> List[nn.Module]:
+ """Construct a miniblock with given input/output-size, norm layer and \
+ activation."""
+ layers: List[nn.Module] = [linear_layer(input_size, output_size)]
+ if norm_layer is not None:
+ if isinstance(norm_args, tuple):
+ layers += [norm_layer(output_size, *norm_args)] # type: ignore
+ elif isinstance(norm_args, dict):
+ layers += [norm_layer(output_size, **norm_args)] # type: ignore
+ else:
+ layers += [norm_layer(output_size)] # type: ignore
+ if activation is not None:
+ if isinstance(act_args, tuple):
+ layers += [activation(*act_args)]
+ elif isinstance(act_args, dict):
+ layers += [activation(**act_args)]
+ else:
+ layers += [activation()]
+ return layers
+
+
+class MLP(nn.Module):
+ """Simple MLP backbone.
+
+ Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ...
+ * hidden_sizes[-1] * output_dim
+
+ :param int input_dim: dimension of the input vector.
+ :param int output_dim: dimension of the output vector. If set to 0, there
+ is no final linear layer.
+ :param hidden_sizes: shape of MLP passed in as a list, not including
+ input_dim and output_dim.
+ :param norm_layer: use which normalization before activation, e.g.,
+ ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
+ You can also pass a list of normalization modules with the same length
+ of hidden_sizes, to use different normalization module in different
+ layers. Default to no normalization.
+ :param activation: which activation to use after each layer, can be both
+ the same activation for all layers if passed in nn.Module, or different
+ activation for different Modules if passed in a list. Default to
+ nn.ReLU.
+ :param device: which device to create this model on. Default to None.
+ :param linear_layer: use this module as linear layer. Default to nn.Linear.
+ :param bool flatten_input: whether to flatten input data. Default to True.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int = 0,
+ hidden_sizes: Sequence[int] = (),
+ norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+ norm_args: Optional[ArgsType] = None,
+ activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+ act_args: Optional[ArgsType] = None,
+ device: Optional[Union[str, int, torch.device]] = None,
+ linear_layer: Type[nn.Linear] = nn.Linear,
+ flatten_input: bool = True,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ if norm_layer:
+ if isinstance(norm_layer, list):
+ assert len(norm_layer) == len(hidden_sizes)
+ norm_layer_list = norm_layer
+ if isinstance(norm_args, list):
+ assert len(norm_args) == len(hidden_sizes)
+ norm_args_list = norm_args
+ else:
+ norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
+ else:
+ norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+ norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
+ else:
+ norm_layer_list = [None] * len(hidden_sizes)
+ norm_args_list = [None] * len(hidden_sizes)
+ if activation:
+ if isinstance(activation, list):
+ assert len(activation) == len(hidden_sizes)
+ activation_list = activation
+ if isinstance(act_args, list):
+ assert len(act_args) == len(hidden_sizes)
+ act_args_list = act_args
+ else:
+ act_args_list = [act_args for _ in range(len(hidden_sizes))]
+ else:
+ activation_list = [activation for _ in range(len(hidden_sizes))]
+ act_args_list = [act_args for _ in range(len(hidden_sizes))]
+ else:
+ activation_list = [None] * len(hidden_sizes)
+ act_args_list = [None] * len(hidden_sizes)
+ hidden_sizes = [input_dim] + list(hidden_sizes)
+ model = []
+ for in_dim, out_dim, norm, norm_args, activ, act_args in zip(
+ hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list,
+ activation_list, act_args_list
+ ):
+ model += miniblock(
+ in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer
+ )
+ if output_dim > 0:
+ model += [linear_layer(hidden_sizes[-1], output_dim)]
+ self.output_dim = output_dim or hidden_sizes[-1]
+ self.model = nn.Sequential(*model)
+ self.flatten_input = flatten_input
+
+ @no_type_check
+ def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ if self.device is not None:
+ obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
+ if self.flatten_input:
+ obs = obs.flatten(1)
+ return self.model(obs)
+
+
+class Net(nn.Module):
+ """Wrapper of MLP to support more specific DRL usage.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ :param state_shape: int or a sequence of int of the shape of state.
+ :param action_shape: int or a sequence of int of the shape of action.
+ :param hidden_sizes: shape of MLP passed in as a list.
+ :param norm_layer: use which normalization before activation, e.g.,
+ ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
+ You can also pass a list of normalization modules with the same length
+ of hidden_sizes, to use different normalization module in different
+ layers. Default to no normalization.
+ :param activation: which activation to use after each layer, can be both
+ the same activation for all layers if passed in nn.Module, or different
+ activation for different Modules if passed in a list. Default to
+ nn.ReLU.
+ :param device: specify the device when the network actually runs. Default
+ to "cpu".
+ :param bool softmax: whether to apply a softmax layer over the last layer's
+ output.
+ :param bool concat: whether the input shape is concatenated by state_shape
+ and action_shape. If it is True, ``action_shape`` is not the output
+ shape, but affects the input shape only.
+ :param int num_atoms: in order to expand to the net of distributional RL.
+ Default to 1 (not use).
+ :param bool dueling_param: whether to use dueling network to calculate Q
+ values (for Dueling DQN). If you want to use dueling option, you should
+ pass a tuple of two dict (first for Q and second for V) stating
+ self-defined arguments as stated in
+ class:`~tianshou.utils.net.common.MLP`. Default to None.
+ :param linear_layer: use this module as linear layer. Default to nn.Linear.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.MLP` for more
+ detailed explanation on the usage of activation, norm_layer, etc.
+
+ You can also refer to :class:`~tianshou.utils.net.continuous.Actor`,
+ :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's
+ suggested be used.
+ """
+
+ def __init__(
+ self,
+ state_shape: Union[int, Sequence[int]],
+ action_shape: Union[int, Sequence[int]] = 0,
+ hidden_sizes: Sequence[int] = (),
+ norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+ norm_args: Optional[ArgsType] = None,
+ activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+ act_args: Optional[ArgsType] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ softmax: bool = False,
+ concat: bool = False,
+ num_atoms: int = 1,
+ dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
+ linear_layer: Type[nn.Linear] = nn.Linear,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.softmax = softmax
+ self.num_atoms = num_atoms
+ input_dim = int(np.prod(state_shape))
+ action_dim = int(np.prod(action_shape)) * num_atoms
+ if concat:
+ input_dim += action_dim
+ self.use_dueling = dueling_param is not None
+ output_dim = action_dim if not self.use_dueling and not concat else 0
+ self.model = MLP(
+ input_dim, output_dim, hidden_sizes, norm_layer, norm_args, activation,
+ act_args, device, linear_layer
+ )
+ self.output_dim = self.model.output_dim
+ if self.use_dueling: # dueling DQN
+ q_kwargs, v_kwargs = dueling_param # type: ignore
+ q_output_dim, v_output_dim = 0, 0
+ if not concat:
+ q_output_dim, v_output_dim = action_dim, num_atoms
+ q_kwargs: Dict[str, Any] = {
+ **q_kwargs, "input_dim": self.output_dim,
+ "output_dim": q_output_dim,
+ "device": self.device
+ }
+ v_kwargs: Dict[str, Any] = {
+ **v_kwargs, "input_dim": self.output_dim,
+ "output_dim": v_output_dim,
+ "device": self.device
+ }
+ self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
+ self.output_dim = self.Q.output_dim
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[torch.Tensor, Any]:
+ """Mapping: obs -> flatten (inside MLP)-> logits."""
+ logits = self.model(obs)
+ bsz = logits.shape[0]
+ if self.use_dueling: # Dueling DQN
+ q, v = self.Q(logits), self.V(logits)
+ if self.num_atoms > 1:
+ q = q.view(bsz, -1, self.num_atoms)
+ v = v.view(bsz, -1, self.num_atoms)
+ logits = q - q.mean(dim=1, keepdim=True) + v
+ elif self.num_atoms > 1:
+ logits = logits.view(bsz, -1, self.num_atoms)
+ if self.softmax:
+ logits = torch.softmax(logits, dim=-1)
+ return logits, state
+
+
+class Recurrent(nn.Module):
+ """Simple Recurrent network based on LSTM.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+ """
+
+ def __init__(
+ self,
+ layer_num: int,
+ state_shape: Union[int, Sequence[int]],
+ action_shape: Union[int, Sequence[int]],
+ device: Union[str, int, torch.device] = "cpu",
+ hidden_layer_size: int = 128,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.nn = nn.LSTM(
+ input_size=hidden_layer_size,
+ hidden_size=hidden_layer_size,
+ num_layers=layer_num,
+ batch_first=True,
+ )
+ self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size)
+ self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape)))
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Optional[Dict[str, torch.Tensor]] = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ """Mapping: obs -> flatten -> logits.
+
+ In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the
+ training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code
+ and comment for more detail.
+ """
+ obs = torch.as_tensor(
+ obs,
+ device=self.device,
+ dtype=torch.float32,
+ )
+ # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+ # In short, the tensor's shape in training phase is longer than which
+ # in evaluation phase.
+ if len(obs.shape) == 2:
+ obs = obs.unsqueeze(-2)
+ obs = self.fc1(obs)
+ self.nn.flatten_parameters()
+ if state is None:
+ obs, (hidden, cell) = self.nn(obs)
+ else:
+ # we store the stack data in [bsz, len, ...] format
+ # but pytorch rnn needs [len, bsz, ...]
+ obs, (hidden, cell) = self.nn(
+ obs, (
+ state["hidden"].transpose(0, 1).contiguous(),
+ state["cell"].transpose(0, 1).contiguous()
+ )
+ )
+ obs = self.fc2(obs[:, -1])
+ # please ensure the first dim is batch size: [bsz, len, ...]
+ return obs, {
+ "hidden": hidden.transpose(0, 1).detach(),
+ "cell": cell.transpose(0, 1).detach()
+ }
+
+
+class ActorCritic(nn.Module):
+ """An actor-critic network for parsing parameters.
+
+ Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid
+ issue #449.
+
+ :param nn.Module actor: the actor network.
+ :param nn.Module critic: the critic network.
+ """
+
+ def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
+ super().__init__()
+ self.actor = actor
+ self.critic = critic
+
+
+class DataParallelNet(nn.Module):
+ """DataParallel wrapper for training agent with multi-GPU.
+
+ This class does only the conversion of input data type, from numpy array to torch's
+ Tensor. If the input is a nested dictionary, the user should create a similar class
+ to do the same thing.
+
+ :param nn.Module net: the network to be distributed in different GPUs.
+ """
+
+ def __init__(self, net: nn.Module) -> None:
+ super().__init__()
+ self.net = nn.DataParallel(net)
+
+ def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any,
+ **kwargs: Any) -> Tuple[Any, Any]:
+ if not isinstance(obs, torch.Tensor):
+ obs = torch.as_tensor(obs, dtype=torch.float32)
+ return self.net(obs=obs.cuda(), *args, **kwargs)
+
+
+class EnsembleLinear(nn.Module):
+ """Linear Layer of Ensemble network.
+
+ :param int ensemble_size: Number of subnets in the ensemble.
+ :param int inp_feature: dimension of the input vector.
+ :param int out_feature: dimension of the output vector.
+ :param bool bias: whether to include an additive bias, default to be True.
+ """
+
+ def __init__(
+ self,
+ ensemble_size: int,
+ in_feature: int,
+ out_feature: int,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ # To be consistent with PyTorch default initializer
+ k = np.sqrt(1. / in_feature)
+ weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k
+ self.weight = nn.Parameter(weight_data, requires_grad=True)
+
+ self.bias: Union[nn.Parameter, None]
+ if bias:
+ bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k
+ self.bias = nn.Parameter(bias_data, requires_grad=True)
+ else:
+ self.bias = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.matmul(x, self.weight)
+ if self.bias is not None:
+ x = x + self.bias
+ return x
+
+
+class BranchingNet(nn.Module):
+ """Branching dual Q network.
+
+ Network for the BranchingDQNPolicy, it uses a common network module, a value module
+ and action "branches" one for each dimension.It allows for a linear scaling
+ of Q-value the output w.r.t. the number of dimensions in the action space.
+ For more info please refer to: arXiv:1711.08946.
+ :param state_shape: int or a sequence of int of the shape of state.
+ :param action_shape: int or a sequence of int of the shape of action.
+ :param action_peer_branch: int or a sequence of int of the number of actions in
+ each dimension.
+ :param common_hidden_sizes: shape of the common MLP network passed in as a list.
+ :param value_hidden_sizes: shape of the value MLP network passed in as a list.
+ :param action_hidden_sizes: shape of the action MLP network passed in as a list.
+ :param norm_layer: use which normalization before activation, e.g.,
+ ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
+ You can also pass a list of normalization modules with the same length
+ of hidden_sizes, to use different normalization module in different
+ layers. Default to no normalization.
+ :param activation: which activation to use after each layer, can be both
+ the same activation for all layers if passed in nn.Module, or different
+ activation for different Modules if passed in a list. Default to
+ nn.ReLU.
+ :param device: specify the device when the network actually runs. Default
+ to "cpu".
+ :param bool softmax: whether to apply a softmax layer over the last layer's
+ output.
+ """
+
+ def __init__(
+ self,
+ state_shape: Union[int, Sequence[int]],
+ num_branches: int = 0,
+ action_per_branch: int = 2,
+ common_hidden_sizes: List[int] = [],
+ value_hidden_sizes: List[int] = [],
+ action_hidden_sizes: List[int] = [],
+ norm_layer: Optional[ModuleType] = None,
+ norm_args: Optional[ArgsType] = None,
+ activation: Optional[ModuleType] = nn.ReLU,
+ act_args: Optional[ArgsType] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.num_branches = num_branches
+ self.action_per_branch = action_per_branch
+ # common network
+ common_input_dim = int(np.prod(state_shape))
+ common_output_dim = 0
+ self.common = MLP(
+ common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
+ norm_args, activation, act_args, device
+ )
+ # value network
+ value_input_dim = common_hidden_sizes[-1]
+ value_output_dim = 1
+ self.value = MLP(
+ value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
+ norm_args, activation, act_args, device
+ )
+ # action branching network
+ action_input_dim = common_hidden_sizes[-1]
+ action_output_dim = action_per_branch
+ self.branches = nn.ModuleList(
+ [
+ MLP(
+ action_input_dim, action_output_dim, action_hidden_sizes,
+ norm_layer, norm_args, activation, act_args, device
+ ) for _ in range(self.num_branches)
+ ]
+ )
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[torch.Tensor, Any]:
+ """Mapping: obs -> model -> logits."""
+ common_out = self.common(obs)
+ value_out = self.value(common_out)
+ value_out = torch.unsqueeze(value_out, 1)
+ action_out = []
+ for b in self.branches:
+ action_out.append(b(common_out))
+ action_scores = torch.stack(action_out, 1)
+ action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
+ logits = value_out + action_scores
+ return logits, state
+
+
+def get_dict_state_decorator(
+ state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str]
+) -> Tuple[Callable, int]:
+ """A helper function to make Net or equivalent classes (e.g. Actor, Critic) \
+ applicable to dict state.
+
+ The first return item, ``decorator_fn``, will alter the implementation of forward
+ function of the given class by preprocessing the observation. The preprocessing is
+ basically flatten the observation and concatenate them based on the ``keys`` order.
+ The batch dimension is preserved if presented. The result observation shape will
+ be equal to ``new_state_shape``, the second return item.
+
+ :param state_shape: A dictionary indicating each state's shape
+ :param keys: A list of state's keys. The flatten observation will be according to \
+ this list order.
+ :returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape``
+ """
+ original_shape = state_shape
+ flat_state_shapes = []
+ for k in keys:
+ flat_state_shapes.append(int(np.prod(state_shape[k])))
+ new_state_shape = sum(flat_state_shapes)
+
+ def preprocess_obs(
+ obs: Union[Batch, dict, torch.Tensor, np.ndarray]
+ ) -> torch.Tensor:
+ if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs):
+ if original_shape[keys[0]] == obs[keys[0]].shape:
+ # No batch dim
+ new_obs = torch.Tensor([obs[k] for k in keys]).flatten()
+ # new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1)
+ else:
+ bsz = obs[keys[0]].shape[0]
+ new_obs = torch.cat(
+ [torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1
+ )
+ else:
+ new_obs = torch.Tensor(obs)
+ return new_obs
+
+ @no_type_check
+ def decorator_fn(net_class):
+
+ class new_net_class(net_class):
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ *args,
+ **kwargs,
+ ) -> Any:
+ return super().forward(preprocess_obs(obs), *args, **kwargs)
+
+ return new_net_class
+
+ return decorator_fn, new_state_shape
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
new file mode 100644
index 0000000..fb75e33
--- /dev/null
+++ b/tianshou/utils/net/continuous.py
@@ -0,0 +1,480 @@
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from tianshou.utils.net.common import MLP
+
+SIGMA_MIN = -20
+SIGMA_MAX = 2
+
+
+class Actor(nn.Module):
+ """Simple actor network. Will create an actor operated in continuous \
+ action space with structure of preprocess_net ---> action_shape.
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param action_shape: a sequence of int for the shape of action.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param float max_action: the scale for the final action logits. Default to
+ 1.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
+ of how preprocess_net is suggested to be defined.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ action_shape: Sequence[int],
+ hidden_sizes: Sequence[int] = (),
+ max_action: float = 1.0,
+ device: Union[str, int, torch.device] = "cpu",
+ preprocess_net_output_dim: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.preprocess = preprocess_net
+ self.output_dim = int(np.prod(action_shape))
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(
+ input_dim, # type: ignore
+ self.output_dim,
+ hidden_sizes,
+ device=self.device
+ )
+ self._max = max_action
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[torch.Tensor, Any]:
+ """Mapping: obs -> logits -> action."""
+ logits, hidden = self.preprocess(obs, state)
+ logits = self._max * torch.tanh(self.last(logits))
+ return logits, hidden
+
+
+class Critic(nn.Module):
+ """Simple critic network. Will create an actor operated in continuous \
+ action space with structure of preprocess_net ---> 1(q value).
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+ :param linear_layer: use this module as linear layer. Default to nn.Linear.
+ :param bool flatten_input: whether to flatten input data for the last layer.
+ Default to True.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
+ of how preprocess_net is suggested to be defined.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ hidden_sizes: Sequence[int] = (),
+ device: Union[str, int, torch.device] = "cpu",
+ preprocess_net_output_dim: Optional[int] = None,
+ linear_layer: Type[nn.Linear] = nn.Linear,
+ flatten_input: bool = True,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.preprocess = preprocess_net
+ self.output_dim = 1
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(
+ input_dim, # type: ignore
+ 1,
+ hidden_sizes,
+ device=self.device,
+ linear_layer=linear_layer,
+ flatten_input=flatten_input,
+ )
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ act: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ info: Dict[str, Any] = {},
+ ) -> torch.Tensor:
+ """Mapping: (s, a) -> logits -> Q(s, a)."""
+ obs = torch.as_tensor(
+ obs,
+ device=self.device,
+ dtype=torch.float32,
+ ).flatten(1)
+ if act is not None:
+ act = torch.as_tensor(
+ act,
+ device=self.device,
+ dtype=torch.float32,
+ ).flatten(1)
+ obs = torch.cat([obs, act], dim=1)
+ logits, hidden = self.preprocess(obs)
+ logits = self.last(logits)
+ return logits
+
+
+class ActorProb(nn.Module):
+ """Simple actor network (output with a Gauss distribution).
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param action_shape: a sequence of int for the shape of action.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param float max_action: the scale for the final action logits. Default to
+ 1.
+ :param bool unbounded: whether to apply tanh activation on final logits.
+ Default to False.
+ :param bool conditioned_sigma: True when sigma is calculated from the
+ input, False when sigma is an independent parameter. Default to False.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
+ of how preprocess_net is suggested to be defined.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ action_shape: Sequence[int],
+ hidden_sizes: Sequence[int] = (),
+ max_action: float = 1.0,
+ device: Union[str, int, torch.device] = "cpu",
+ unbounded: bool = False,
+ conditioned_sigma: bool = False,
+ preprocess_net_output_dim: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ self.preprocess = preprocess_net
+ self.device = device
+ self.output_dim = int(np.prod(action_shape))
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.mu = MLP(
+ input_dim, # type: ignore
+ self.output_dim,
+ hidden_sizes,
+ device=self.device
+ )
+ self._c_sigma = conditioned_sigma
+ if conditioned_sigma:
+ self.sigma = MLP(
+ input_dim, # type: ignore
+ self.output_dim,
+ hidden_sizes,
+ device=self.device
+ )
+ else:
+ self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
+ self._max = max_action
+ self._unbounded = unbounded
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
+ """Mapping: obs -> logits -> (mu, sigma)."""
+ logits, hidden = self.preprocess(obs, state)
+ mu = self.mu(logits)
+ if not self._unbounded:
+ mu = self._max * torch.tanh(mu)
+ if self._c_sigma:
+ sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
+ else:
+ shape = [1] * len(mu.shape)
+ shape[1] = -1
+ sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
+ return (mu, sigma), state
+
+
+class RecurrentActorProb(nn.Module):
+ """Recurrent version of ActorProb.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+ """
+
+ def __init__(
+ self,
+ layer_num: int,
+ state_shape: Sequence[int],
+ action_shape: Sequence[int],
+ hidden_layer_size: int = 128,
+ max_action: float = 1.0,
+ device: Union[str, int, torch.device] = "cpu",
+ unbounded: bool = False,
+ conditioned_sigma: bool = False,
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.nn = nn.LSTM(
+ input_size=int(np.prod(state_shape)),
+ hidden_size=hidden_layer_size,
+ num_layers=layer_num,
+ batch_first=True,
+ )
+ output_dim = int(np.prod(action_shape))
+ self.mu = nn.Linear(hidden_layer_size, output_dim)
+ self._c_sigma = conditioned_sigma
+ if conditioned_sigma:
+ self.sigma = nn.Linear(hidden_layer_size, output_dim)
+ else:
+ self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
+ self._max = max_action
+ self._unbounded = unbounded
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Optional[Dict[str, torch.Tensor]] = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
+ """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
+ obs = torch.as_tensor(
+ obs,
+ device=self.device,
+ dtype=torch.float32,
+ )
+ # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+ # In short, the tensor's shape in training phase is longer than which
+ # in evaluation phase.
+ if len(obs.shape) == 2:
+ obs = obs.unsqueeze(-2)
+ self.nn.flatten_parameters()
+ if state is None:
+ obs, (hidden, cell) = self.nn(obs)
+ else:
+ # we store the stack data in [bsz, len, ...] format
+ # but pytorch rnn needs [len, bsz, ...]
+ obs, (hidden, cell) = self.nn(
+ obs, (
+ state["hidden"].transpose(0, 1).contiguous(),
+ state["cell"].transpose(0, 1).contiguous()
+ )
+ )
+ logits = obs[:, -1]
+ mu = self.mu(logits)
+ if not self._unbounded:
+ mu = self._max * torch.tanh(mu)
+ if self._c_sigma:
+ sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
+ else:
+ shape = [1] * len(mu.shape)
+ shape[1] = -1
+ sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
+ # please ensure the first dim is batch size: [bsz, len, ...]
+ return (mu, sigma), {
+ "hidden": hidden.transpose(0, 1).detach(),
+ "cell": cell.transpose(0, 1).detach()
+ }
+
+
+class RecurrentCritic(nn.Module):
+ """Recurrent version of Critic.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+ """
+
+ def __init__(
+ self,
+ layer_num: int,
+ state_shape: Sequence[int],
+ action_shape: Sequence[int] = [0],
+ device: Union[str, int, torch.device] = "cpu",
+ hidden_layer_size: int = 128,
+ ) -> None:
+ super().__init__()
+ self.state_shape = state_shape
+ self.action_shape = action_shape
+ self.device = device
+ self.nn = nn.LSTM(
+ input_size=int(np.prod(state_shape)),
+ hidden_size=hidden_layer_size,
+ num_layers=layer_num,
+ batch_first=True,
+ )
+ self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1)
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ act: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ info: Dict[str, Any] = {},
+ ) -> torch.Tensor:
+ """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
+ obs = torch.as_tensor(
+ obs,
+ device=self.device,
+ dtype=torch.float32,
+ )
+ # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
+ # In short, the tensor's shape in training phase is longer than which
+ # in evaluation phase.
+ assert len(obs.shape) == 3
+ self.nn.flatten_parameters()
+ obs, (hidden, cell) = self.nn(obs)
+ obs = obs[:, -1]
+ if act is not None:
+ act = torch.as_tensor(
+ act,
+ device=self.device,
+ dtype=torch.float32,
+ )
+ obs = torch.cat([obs, act], dim=1)
+ obs = self.fc2(obs)
+ return obs
+
+
+class Perturbation(nn.Module):
+ """Implementation of perturbation network in BCQ algorithm. Given a state and \
+ action, it can generate perturbed action.
+
+ :param torch.nn.Module preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param float max_action: the maximum value of each dimension of action.
+ :param Union[str, int, torch.device] device: which device to create this model on.
+ Default to cpu.
+ :param float phi: max perturbation parameter for BCQ. Default to 0.05.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ You can refer to `examples/offline/offline_bcq.py` to see how to use it.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ max_action: float,
+ device: Union[str, int, torch.device] = "cpu",
+ phi: float = 0.05
+ ):
+ # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
+ super(Perturbation, self).__init__()
+ self.preprocess_net = preprocess_net
+ self.device = device
+ self.max_action = max_action
+ self.phi = phi
+
+ def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
+ # preprocess_net
+ logits = self.preprocess_net(torch.cat([state, action], -1))[0]
+ noise = self.phi * self.max_action * torch.tanh(logits)
+ # clip to [-max_action, max_action]
+ return (noise + action).clamp(-self.max_action, self.max_action)
+
+
+class VAE(nn.Module):
+ """Implementation of VAE. It models the distribution of action. Given a \
+ state, it can generate actions similar to those in batch. It is used \
+ in BCQ algorithm.
+
+ :param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
+ state_dim + action_dim, and output_dim must be hidden_dim.
+ :param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
+ state_dim + latent_dim, and output_dim must be action_dim.
+ :param int hidden_dim: the size of the last linear-layer in encoder.
+ :param int latent_dim: the size of latent layer.
+ :param float max_action: the maximum value of each dimension of action.
+ :param Union[str, torch.device] device: which device to create this model on.
+ Default to "cpu".
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ You can refer to `examples/offline/offline_bcq.py` to see how to use it.
+ """
+
+ def __init__(
+ self,
+ encoder: nn.Module,
+ decoder: nn.Module,
+ hidden_dim: int,
+ latent_dim: int,
+ max_action: float,
+ device: Union[str, torch.device] = "cpu"
+ ):
+ super(VAE, self).__init__()
+ self.encoder = encoder
+
+ self.mean = nn.Linear(hidden_dim, latent_dim)
+ self.log_std = nn.Linear(hidden_dim, latent_dim)
+
+ self.decoder = decoder
+
+ self.max_action = max_action
+ self.latent_dim = latent_dim
+ self.device = device
+
+ def forward(
+ self, state: torch.Tensor, action: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # [state, action] -> z , [state, z] -> action
+ latent_z = self.encoder(torch.cat([state, action], -1))
+ # shape of z: (state.shape[:-1], hidden_dim)
+
+ mean = self.mean(latent_z)
+ # Clamped for numerical stability
+ log_std = self.log_std(latent_z).clamp(-4, 15)
+ std = torch.exp(log_std)
+ # shape of mean, std: (state.shape[:-1], latent_dim)
+
+ latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
+
+ reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim)
+ return reconstruction, mean, std
+
+ def decode(
+ self,
+ state: torch.Tensor,
+ latent_z: Union[torch.Tensor, None] = None
+ ) -> torch.Tensor:
+ # decode(state) -> action
+ if latent_z is None:
+ # state.shape[0] may be batch_size
+ # latent vector clipped to [-0.5, 0.5]
+ latent_z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
+ .to(self.device).clamp(-0.5, 0.5)
+
+ # decode z with state!
+ return self.max_action * \
+ torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
new file mode 100644
index 0000000..b976585
--- /dev/null
+++ b/tianshou/utils/net/discrete.py
@@ -0,0 +1,449 @@
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from tianshou.data import Batch, to_torch
+from tianshou.utils.net.common import MLP
+
+
+class Actor(nn.Module):
+ """Simple actor network.
+
+ Will create an actor operated in discrete action space with structure of
+ preprocess_net ---> action_shape.
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param action_shape: a sequence of int for the shape of action.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param bool softmax_output: whether to apply a softmax layer over the last
+ layer's output.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
+ of how preprocess_net is suggested to be defined.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ action_shape: Sequence[int],
+ hidden_sizes: Sequence[int] = (),
+ softmax_output: bool = True,
+ preprocess_net_output_dim: Optional[int] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.preprocess = preprocess_net
+ self.output_dim = int(np.prod(action_shape))
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(
+ input_dim, # type: ignore
+ self.output_dim,
+ hidden_sizes,
+ device=self.device
+ )
+ self.softmax_output = softmax_output
+
+ def forward(
+ self,
+ obs: Union[np.ndarray, torch.Tensor],
+ state: Any = None,
+ info: Dict[str, Any] = {},
+ ) -> Tuple[torch.Tensor, Any]:
+ r"""Mapping: s -> Q(s, \*)."""
+ logits, hidden = self.preprocess(obs, state)
+ logits = self.last(logits)
+ if self.softmax_output:
+ logits = F.softmax(logits, dim=-1)
+ return logits, hidden
+
+
+class Critic(nn.Module):
+ """Simple critic network. Will create an actor operated in discrete \
+ action space with structure of preprocess_net ---> 1(q value).
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param int last_size: the output dimension of Critic network. Default to 1.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ For advanced usage (how to customize the network), please refer to
+ :ref:`build_the_network`.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
+ of how preprocess_net is suggested to be defined.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ hidden_sizes: Sequence[int] = (),
+ last_size: int = 1,
+ preprocess_net_output_dim: Optional[int] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__()
+ self.device = device
+ self.preprocess = preprocess_net
+ self.output_dim = last_size
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(
+ input_dim, # type: ignore
+ last_size,
+ hidden_sizes,
+ device=self.device
+ )
+
+ def forward(
+ self, obs: Union[np.ndarray, torch.Tensor], **kwargs: Any
+ ) -> torch.Tensor:
+ """Mapping: s -> V(s)."""
+ logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
+ return self.last(logits)
+
+
+class CosineEmbeddingNetwork(nn.Module):
+ """Cosine embedding network for IQN. Convert a scalar in [0, 1] to a list \
+ of n-dim vectors.
+
+ :param num_cosines: the number of cosines used for the embedding.
+ :param embedding_dim: the dimension of the embedding/output.
+
+ .. note::
+
+ From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
+ /fqf_iqn_qrdqn/network.py .
+ """
+
+ def __init__(self, num_cosines: int, embedding_dim: int) -> None:
+ super().__init__()
+ self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU())
+ self.num_cosines = num_cosines
+ self.embedding_dim = embedding_dim
+
+ def forward(self, taus: torch.Tensor) -> torch.Tensor:
+ batch_size = taus.shape[0]
+ N = taus.shape[1]
+ # Calculate i * \pi (i=1,...,N).
+ i_pi = np.pi * torch.arange(
+ start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device
+ ).view(1, 1, self.num_cosines)
+ # Calculate cos(i * \pi * \tau).
+ cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi
+ ).view(batch_size * N, self.num_cosines)
+ # Calculate embeddings of taus.
+ tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim)
+ return tau_embeddings
+
+
+class ImplicitQuantileNetwork(Critic):
+ """Implicit Quantile Network.
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param int action_shape: a sequence of int for the shape of action.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param int num_cosines: the number of cosines to use for cosine embedding.
+ Default to 64.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ .. note::
+
+ Although this class inherits Critic, it is actually a quantile Q-Network
+ with output shape (batch_size, action_dim, sample_size).
+
+ The second item of the first return value is tau vector.
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ action_shape: Sequence[int],
+ hidden_sizes: Sequence[int] = (),
+ num_cosines: int = 64,
+ preprocess_net_output_dim: Optional[int] = None,
+ device: Union[str, int, torch.device] = "cpu"
+ ) -> None:
+ last_size = int(np.prod(action_shape))
+ super().__init__(
+ preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device
+ )
+ self.input_dim = getattr(
+ preprocess_net, "output_dim", preprocess_net_output_dim
+ )
+ self.embed_model = CosineEmbeddingNetwork(
+ num_cosines,
+ self.input_dim # type: ignore
+ ).to(device)
+
+ def forward( # type: ignore
+ self, obs: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
+ ) -> Tuple[Any, torch.Tensor]:
+ r"""Mapping: s -> Q(s, \*)."""
+ logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
+ # Sample fractions.
+ batch_size = logits.size(0)
+ taus = torch.rand(
+ batch_size, sample_size, dtype=logits.dtype, device=logits.device
+ )
+ embedding = (logits.unsqueeze(1) *
+ self.embed_model(taus)).view(batch_size * sample_size, -1)
+ out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
+ return (out, taus), hidden
+
+
+class FractionProposalNetwork(nn.Module):
+ """Fraction proposal network for FQF.
+
+ :param num_fractions: the number of factions to propose.
+ :param embedding_dim: the dimension of the embedding/input.
+
+ .. note::
+
+ Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
+ /fqf_iqn_qrdqn/network.py .
+ """
+
+ def __init__(self, num_fractions: int, embedding_dim: int) -> None:
+ super().__init__()
+ self.net = nn.Linear(embedding_dim, num_fractions)
+ torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)
+ torch.nn.init.constant_(self.net.bias, 0)
+ self.num_fractions = num_fractions
+ self.embedding_dim = embedding_dim
+
+ def forward(
+ self, obs_embeddings: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Calculate (log of) probabilities q_i in the paper.
+ dist = torch.distributions.Categorical(logits=self.net(obs_embeddings))
+ taus_1_N = torch.cumsum(dist.probs, dim=1)
+ # Calculate \tau_i (i=0,...,N).
+ taus = F.pad(taus_1_N, (1, 0))
+ # Calculate \hat \tau_i (i=0,...,N-1).
+ tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
+ # Calculate entropies of value distributions.
+ entropies = dist.entropy()
+ return taus, tau_hats, entropies
+
+
+class FullQuantileFunction(ImplicitQuantileNetwork):
+ """Full(y parameterized) Quantile Function.
+
+ :param preprocess_net: a self-defined preprocess_net which output a
+ flattened hidden state.
+ :param int action_shape: a sequence of int for the shape of action.
+ :param hidden_sizes: a sequence of int for constructing the MLP after
+ preprocess_net. Default to empty sequence (where the MLP now contains
+ only a single linear layer).
+ :param int num_cosines: the number of cosines to use for cosine embedding.
+ Default to 64.
+ :param int preprocess_net_output_dim: the output dimension of
+ preprocess_net.
+
+ .. note::
+
+ The first return value is a tuple of (quantiles, fractions, quantiles_tau),
+ where fractions is a Batch(taus, tau_hats, entropies).
+ """
+
+ def __init__(
+ self,
+ preprocess_net: nn.Module,
+ action_shape: Sequence[int],
+ hidden_sizes: Sequence[int] = (),
+ num_cosines: int = 64,
+ preprocess_net_output_dim: Optional[int] = None,
+ device: Union[str, int, torch.device] = "cpu",
+ ) -> None:
+ super().__init__(
+ preprocess_net, action_shape, hidden_sizes, num_cosines,
+ preprocess_net_output_dim, device
+ )
+
+ def _compute_quantiles(
+ self, obs: torch.Tensor, taus: torch.Tensor
+ ) -> torch.Tensor:
+ batch_size, sample_size = taus.shape
+ embedding = (obs.unsqueeze(1) *
+ self.embed_model(taus)).view(batch_size * sample_size, -1)
+ quantiles = self.last(embedding).view(batch_size, sample_size,
+ -1).transpose(1, 2)
+ return quantiles
+
+ def forward( # type: ignore
+ self, obs: Union[np.ndarray, torch.Tensor],
+ propose_model: FractionProposalNetwork,
+ fractions: Optional[Batch] = None,
+ **kwargs: Any
+ ) -> Tuple[Any, torch.Tensor]:
+ r"""Mapping: s -> Q(s, \*)."""
+ logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
+ # Propose fractions
+ if fractions is None:
+ taus, tau_hats, entropies = propose_model(logits.detach())
+ fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)
+ else:
+ taus, tau_hats = fractions.taus, fractions.tau_hats
+ quantiles = self._compute_quantiles(logits, tau_hats)
+ # Calculate quantiles_tau for computing fraction grad
+ quantiles_tau = None
+ if self.training:
+ with torch.no_grad():
+ quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
+ return (quantiles, fractions, quantiles_tau), hidden
+
+
+class NoisyLinear(nn.Module):
+ """Implementation of Noisy Networks. arXiv:1706.10295.
+
+ :param int in_features: the number of input features.
+ :param int out_features: the number of output features.
+ :param float noisy_std: initial standard deviation of noisy linear layers.
+
+ .. note::
+
+ Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
+ /fqf_iqn_qrdqn/network.py .
+ """
+
+ def __init__(
+ self, in_features: int, out_features: int, noisy_std: float = 0.5
+ ) -> None:
+ super().__init__()
+
+ # Learnable parameters.
+ self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features))
+ self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features))
+ self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
+ self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))
+
+ # Factorized noise parameters.
+ self.register_buffer('eps_p', torch.FloatTensor(in_features))
+ self.register_buffer('eps_q', torch.FloatTensor(out_features))
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.sigma = noisy_std
+
+ self.reset()
+ self.sample()
+
+ def reset(self) -> None:
+ bound = 1 / np.sqrt(self.in_features)
+ self.mu_W.data.uniform_(-bound, bound)
+ self.mu_bias.data.uniform_(-bound, bound)
+ self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features))
+ self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features))
+
+ def f(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.randn(x.size(0), device=x.device)
+ return x.sign().mul_(x.abs().sqrt_())
+
+ def sample(self) -> None:
+ self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
+ self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.training:
+ weight = self.mu_W + self.sigma_W * (
+ self.eps_q.ger(self.eps_p) # type: ignore
+ )
+ bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
+ else:
+ weight = self.mu_W
+ bias = self.mu_bias
+
+ return F.linear(x, weight, bias)
+
+
+def sample_noise(model: nn.Module) -> bool:
+ """Sample the random noises of NoisyLinear modules in the model.
+
+ :param model: a PyTorch module which may have NoisyLinear submodules.
+ :returns: True if model has at least one NoisyLinear submodule;
+ otherwise, False.
+ """
+ done = False
+ for m in model.modules():
+ if isinstance(m, NoisyLinear):
+ m.sample()
+ done = True
+ return done
+
+
+class IntrinsicCuriosityModule(nn.Module):
+ """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.
+
+ :param torch.nn.Module feature_net: a self-defined feature_net which output a
+ flattened hidden state.
+ :param int feature_dim: input dimension of the feature net.
+ :param int action_dim: dimension of the action space.
+ :param hidden_sizes: hidden layer sizes for forward and inverse models.
+ :param device: device for the module.
+ """
+
+ def __init__(
+ self,
+ feature_net: nn.Module,
+ feature_dim: int,
+ action_dim: int,
+ hidden_sizes: Sequence[int] = (),
+ device: Union[str, torch.device] = "cpu"
+ ) -> None:
+ super().__init__()
+ self.feature_net = feature_net
+ self.forward_model = MLP(
+ feature_dim + action_dim,
+ output_dim=feature_dim,
+ hidden_sizes=hidden_sizes,
+ device=device
+ )
+ self.inverse_model = MLP(
+ feature_dim * 2,
+ output_dim=action_dim,
+ hidden_sizes=hidden_sizes,
+ device=device
+ )
+ self.feature_dim = feature_dim
+ self.action_dim = action_dim
+ self.device = device
+
+ def forward(
+ self, s1: Union[np.ndarray, torch.Tensor],
+ act: Union[np.ndarray, torch.Tensor], s2: Union[np.ndarray,
+ torch.Tensor], **kwargs: Any
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Mapping: s1, act, s2 -> mse_loss, act_hat."""
+ s1 = to_torch(s1, dtype=torch.float32, device=self.device)
+ s2 = to_torch(s2, dtype=torch.float32, device=self.device)
+ phi1, phi2 = self.feature_net(s1), self.feature_net(s2)
+ act = to_torch(act, dtype=torch.long, device=self.device)
+ phi2_hat = self.forward_model(
+ torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1)
+ )
+ mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1)
+ act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
+ return mse_loss, act_hat
diff --git a/tianshou/utils/progress_bar.py b/tianshou/utils/progress_bar.py
new file mode 100644
index 0000000..dc3cd03
--- /dev/null
+++ b/tianshou/utils/progress_bar.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+tqdm_config = {
+ "dynamic_ncols": True,
+ "ascii": True,
+}
+
+
+class DummyTqdm:
+ """A dummy tqdm class that keeps stats but without progress bar.
+
+ It supports ``__enter__`` and ``__exit__``, update and a dummy
+ ``set_postfix``, which is the interface that trainers use.
+
+ .. note::
+
+ Using ``disable=True`` in tqdm config results in infinite loop, thus
+ this class is created. See the discussion at #641 for details.
+ """
+
+ def __init__(self, total: int, **kwargs: Any):
+ self.total = total
+ self.n = 0
+
+ def set_postfix(self, **kwargs: Any) -> None:
+ pass
+
+ def update(self, n: int = 1) -> None:
+ self.n += n
+
+ def __enter__(self) -> "DummyTqdm":
+ return self
+
+ def __exit__(self, *args: Any, **kwargs: Any) -> None:
+ pass
diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py
new file mode 100644
index 0000000..505bf79
--- /dev/null
+++ b/tianshou/utils/statistics.py
@@ -0,0 +1,114 @@
+from numbers import Number
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+
+class MovAvg(object):
+ """Class for moving average.
+
+ It will automatically exclude the infinity and NaN. Usage:
+ ::
+
+ >>> stat = MovAvg(size=66)
+ >>> stat.add(torch.tensor(5))
+ 5.0
+ >>> stat.add(float('inf')) # which will not add to stat
+ 5.0
+ >>> stat.add([6, 7, 8])
+ 6.5
+ >>> stat.get()
+ 6.5
+ >>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
+ 6.50±1.12
+ """
+
+ def __init__(self, size: int = 100) -> None:
+ super().__init__()
+ self.size = size
+ self.cache: List[np.number] = []
+ self.banned = [np.inf, np.nan, -np.inf]
+
+ def add(
+ self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor]
+ ) -> float:
+ """Add a scalar into :class:`MovAvg`.
+
+ You can add ``torch.Tensor`` with only one element, a python scalar, or
+ a list of python scalar.
+ """
+ if isinstance(data_array, torch.Tensor):
+ data_array = data_array.flatten().cpu().numpy()
+ if np.isscalar(data_array):
+ data_array = [data_array]
+ for number in data_array: # type: ignore
+ if number not in self.banned:
+ self.cache.append(number)
+ if self.size > 0 and len(self.cache) > self.size:
+ self.cache = self.cache[-self.size:]
+ return self.get()
+
+ def get(self) -> float:
+ """Get the average."""
+ if len(self.cache) == 0:
+ return 0.0
+ return float(np.mean(self.cache)) # type: ignore
+
+ def mean(self) -> float:
+ """Get the average. Same as :meth:`get`."""
+ return self.get()
+
+ def std(self) -> float:
+ """Get the standard deviation."""
+ if len(self.cache) == 0:
+ return 0.0
+ return float(np.std(self.cache)) # type: ignore
+
+
+class RunningMeanStd(object):
+ """Calculates the running mean and std of a data stream.
+
+ https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+
+ :param mean: the initial mean estimation for data array. Default to 0.
+ :param std: the initial standard error estimation for data array. Default to 1.
+ :param float clip_max: the maximum absolute value for data array. Default to
+ 10.0.
+ :param float epsilon: To avoid division by zero.
+ """
+
+ def __init__(
+ self,
+ mean: Union[float, np.ndarray] = 0.0,
+ std: Union[float, np.ndarray] = 1.0,
+ clip_max: Optional[float] = 10.0,
+ epsilon: float = np.finfo(np.float32).eps.item(),
+ ) -> None:
+ self.mean, self.var = mean, std
+ self.clip_max = clip_max
+ self.count = 0
+ self.eps = epsilon
+
+ def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
+ data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
+ if self.clip_max:
+ data_array = np.clip(data_array, -self.clip_max, self.clip_max)
+ return data_array
+
+ def update(self, data_array: np.ndarray) -> None:
+ """Add a batch of item into RMS with the same shape, modify mean/var/count."""
+ batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
+ batch_count = len(data_array)
+
+ delta = batch_mean - self.mean
+ total_count = self.count + batch_count
+
+ new_mean = self.mean + delta * batch_count / total_count
+ m_a = self.var * self.count
+ m_b = batch_var * batch_count
+ m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
+ new_var = m_2 / total_count
+
+ self.mean, self.var = new_mean, new_var
+ self.count = total_count
diff --git a/tianshou/utils/warning.py b/tianshou/utils/warning.py
new file mode 100644
index 0000000..93c5cce
--- /dev/null
+++ b/tianshou/utils/warning.py
@@ -0,0 +1,8 @@
+import warnings
+
+warnings.simplefilter("once", DeprecationWarning)
+
+
+def deprecation(msg: str) -> None:
+ """Deprecation warning wrapper."""
+ warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
diff --git a/tools.py b/tools.py
new file mode 100755
index 0000000..5bbb2b5
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,104 @@
+import os
+import argparse
+import time
+import random
+from shutil import copyfile, copytree
+from distutils.util import strtobool
+
+import numpy as np
+import torch
+import torch.nn as nn
+from gymnasium.envs.registration import register
+
+
+def backup(time_str, args, upper_policy=None):
+ if args.mode == "test":
+ target_dir = os.path.join('./logs/evaluation', time_str)
+ else:
+ target_dir = os.path.join('./logs/experiment', time_str)
+
+ if not os.path.exists(target_dir):
+ os.makedirs(target_dir)
+ # TODO
+ copyfile('config.py', os.path.join(target_dir, 'config.py'))
+
+ # gym_path = './problems'
+ # env_name = args.id.split('-v')
+ # env_name = env_name[0] + env_name[1]
+ # env_path = os.path.join(gym_path, env_name)
+ # copytree(env_path, os.path.join(target_dir, env_name))
+
+ if upper_policy is not None:
+ torch.save(upper_policy.state_dict(),
+ os.path.join(args.model_save_path, time_str, 'upper-first-' + time_str + ".pt"))
+
+
+def registration_envs():
+ register(
+ id='OnlinePack-v1',
+ entry_point='envs.Packing:PackingEnv',
+ )
+
+
+def load_policy(load_path, model, device="cpu"):
+ print(f"load model from: {load_path}")
+ assert os.path.exists(load_path), 'File does not exist'
+ pretrained_state_dict = torch.load(load_path, map_location=device)
+ if len(pretrained_state_dict) == 2:
+ pretrained_state_dict, ob_rms = pretrained_state_dict
+
+ load_dict = {}
+ for k, v in pretrained_state_dict.items():
+ if 'actor.embedder.layers' in k:
+ load_dict[k.replace('module.weight', 'weight')] = v
+ else:
+ load_dict[k.replace('module.', '')] = v
+
+ load_dict = {k.replace('add_bias.', ''): v for k, v in load_dict.items()}
+ load_dict = {k.replace('_bias', 'bias'): v for k, v in load_dict.items()}
+
+ model.load_state_dict(load_dict, strict=True)
+ print('Loading pre-train upper model', load_path)
+ return model
+
+
+def set_seed(seed: int, cuda: bool = False, cuda_deterministic: bool = False):
+ """set random seed"""
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+ if cuda and torch.cuda.is_available() and cuda_deterministic:
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ # torch.use_deterministic_algorithms(True)
+
+
+class CategoricalMasked(torch.distributions.Categorical):
+ def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
+ self.device = logits.device
+ self.masks = masks
+ if len(self.masks) == 0:
+ super(CategoricalMasked, self).__init__(probs, logits, validate_args)
+ else:
+ self.masks = masks.type(torch.BoolTensor).to(self.device)
+ logits = torch.where(self.masks, logits, torch.tensor(-1e18).to(self.device))
+ super(CategoricalMasked, self).__init__(probs, logits, validate_args)
+
+ def entropy(self):
+ if len(self.masks) == 0:
+ return super(CategoricalMasked, self).entropy()
+ p_log_p = self.logits * self.probs
+ p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(self.device))
+ return -p_log_p.sum(-1)
+
+
+if __name__ == '__main__':
+ registration_envs()
+ from gymnasium import envs
+ envids = [spec.id for spec in envs.registry.all()]
+ print(envids)
diff --git a/ts_test.py b/ts_test.py
new file mode 100644
index 0000000..0b06d50
--- /dev/null
+++ b/ts_test.py
@@ -0,0 +1,97 @@
+import os
+import sys
+curr_path = os.path.dirname(os.path.abspath(__file__))
+parent_path = os.path.dirname(curr_path)
+sys.path.append(parent_path)
+
+import random
+
+import gymnasium as gym
+import torch
+from tianshou.utils.net.common import ActorCritic
+
+from ts_train import build_net
+import arguments
+from tools import *
+from mycollector import PackCollector
+from masked_ppo import MaskedPPOPolicy
+
+
+def test(args):
+
+ if args.cuda and torch.cuda.is_available():
+ device = torch.device("cuda", args.device)
+ else:
+ device = torch.device("cpu")
+
+ set_seed(args.seed, args.cuda, args.cuda_deterministic)
+
+ # environment
+ test_env = gym.make(
+ args.env.id,
+ container_size=args.env.container_size,
+ enable_rotation=args.env.rot,
+ data_type=args.env.box_type,
+ item_set=args.env.box_size_set,
+ reward_type=args.train.reward_type,
+ action_scheme=args.env.scheme,
+ k_placement=args.env.k_placement,
+ is_render=args.render
+ )
+
+ # network
+ actor, critic = build_net(args, device)
+ actor_critic = ActorCritic(actor, critic)
+
+ optim = torch.optim.Adam(actor_critic.parameters(), lr=args.opt.lr, eps=args.opt.eps)
+
+ # RL agent
+ dist = CategoricalMasked
+
+ policy = MaskedPPOPolicy(
+ actor=actor,
+ critic=critic,
+ optim=optim,
+ dist_fn=dist,
+ discount_factor=args.train.gamma,
+ eps_clip=args.train.clip_param,
+ advantage_normalization=False,
+ vf_coef=args.loss.value,
+ ent_coef=args.loss.entropy,
+ gae_lambda=args.train.gae_lambda,
+ action_space=test_env.action_space,
+ )
+
+ policy.eval()
+ try:
+ policy.load_state_dict(torch.load(args.ckp, map_location=device))
+ # print(policy)
+ except FileNotFoundError:
+ print("No model found")
+ exit()
+
+ test_collector = PackCollector(policy, test_env)
+
+ # Evaluation
+ result = test_collector.collect(n_episode=args.test_episode, render=args.render)
+ for i in range(args.test_episode):
+ print(f"episode {i+1}\t => \tratio: {result['ratios'][i]:.4f} \t| total: {result['nums'][i]}")
+ print('All cases have been done!')
+ print('----------------------------------------------')
+ print('average space utilization: %.4f'%(result['ratio']))
+ print('average put item number: %.4f'%(result['num']))
+ print("standard variance: %.4f"%(result['ratio_std']))
+
+
+if __name__ == '__main__':
+ registration_envs()
+ args = arguments.get_args()
+ args.train.algo = args.train.algo.upper()
+ args.train.step_per_collect = args.train.num_processes * args.train.num_steps
+
+ if args.render:
+ args.test_episode = 1 # for visualization
+
+ args.seed = 5
+ print(f"dimension: {args.env.container_size}")
+ test(args)
diff --git a/ts_train.py b/ts_train.py
new file mode 100644
index 0000000..c111507
--- /dev/null
+++ b/ts_train.py
@@ -0,0 +1,269 @@
+"""
+refer:
+- https://github.com/albertcity/OCARL
+- https://github.com/pioneer-innovation/Real-3D-Embodied-Dataset
+
+"""
+import sys
+import os
+curr_path = os.path.dirname(os.path.abspath(__file__))
+parent_path = os.path.dirname(curr_path)
+sys.path.append(parent_path)
+import warnings
+warnings.filterwarnings("ignore")
+
+import time
+import pprint
+import shutil
+import random
+
+import numpy as np
+import gymnasium as gym
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+from torch.optim.lr_scheduler import LambdaLR, ExponentialLR
+import tianshou as ts
+from tianshou.utils import TensorboardLogger, LazyLogger
+from tianshou.data import VectorReplayBuffer
+from tianshou.utils.net.common import ActorCritic, DataParallelNet
+from tianshou.trainer import onpolicy_trainer
+
+import model
+import arguments
+from tools import *
+from masked_ppo import MaskedPPOPolicy
+from masked_a2c import MaskedA2CPolicy
+from mycollector import PackCollector
+
+
+def make_envs(args):
+
+ train_envs = ts.env.SubprocVectorEnv(
+ [lambda: gym.make(args.env.id,
+ container_size=args.env.container_size,
+ enable_rotation=args.env.rot,
+ data_type=args.env.box_type,
+ item_set=args.env.box_size_set,
+ reward_type=args.train.reward_type,
+ action_scheme=args.env.scheme,
+ k_placement=args.env.k_placement)
+ for _ in range(args.train.num_processes)]
+ )
+ test_envs = ts.env.SubprocVectorEnv(
+ [lambda: gym.make(args.env.id,
+ container_size=args.env.container_size,
+ enable_rotation=args.env.rot,
+ data_type=args.env.box_type,
+ item_set=args.env.box_size_set,
+ reward_type=args.train.reward_type,
+ action_scheme=args.env.scheme,
+ k_placement=args.env.k_placement)
+ for _ in range(1)]
+ )
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+
+ return train_envs, test_envs
+
+
+def build_net(args, device):
+ feature_net = model.ShareNet(
+ k_placement=args.env.k_placement,
+ box_max_size=args.env.box_big,
+ container_size=args.env.container_size,
+ embed_size=args.model.embed_dim,
+ num_layers=args.model.num_layers,
+ forward_expansion=args.model.forward_expansion,
+ heads=args.model.heads,
+ dropout=args.model.dropout,
+ device=device,
+ place_gen=args.env.scheme
+ )
+
+ actor = model.ActorHead(
+ preprocess_net=feature_net,
+ embed_size=args.model.embed_dim,
+ padding_mask=args.model.padding_mask,
+ device=device,
+ ).to(device)
+
+ critic = model.CriticHead(
+ preprocess_net=feature_net,
+ k_placement=args.env.k_placement,
+ embed_size=args.model.embed_dim,
+ padding_mask=args.model.padding_mask,
+ device=device,
+ ).to(device)
+
+ return actor, critic
+
+
+def train(args):
+
+ date = time.strftime(r'%Y.%m.%d-%H-%M-%S', time.localtime(time.time()))
+ time_str = args.env.id + "_" + \
+ str(args.env.container_size[0]) + "-" + str(args.env.container_size[1]) + "-" + str(args.env.container_size[2]) + "_" + \
+ args.env.scheme + "_" + str(args.env.k_placement) + "_" +\
+ args.env.box_type + "_" + \
+ args.train.algo + '_' \
+ 'seed' + str(args.seed) + "_" + \
+ args.opt.optimizer + "_" \
+ + date
+
+ if args.cuda and torch.cuda.is_available():
+ device = torch.device("cuda", args.device)
+ else:
+ device = torch.device("cpu")
+
+ set_seed(args.seed, args.cuda, args.cuda_deterministic)
+
+ # environments
+ train_envs, test_envs = make_envs(args) # make envs and set random seed
+
+ # network
+ actor, critic = build_net(args, device)
+ actor_critic = ActorCritic(actor, critic)
+
+ if args.opt.optimizer == 'Adam':
+ optim = torch.optim.Adam(actor_critic.parameters(), lr=args.opt.lr, eps=args.opt.eps)
+ elif args.opt.optimizer == 'RMSprop':
+ optim = torch.optim.RMSprop(
+ actor_critic.parameters(),
+ lr=args.opt.lr,
+ eps=args.opt.eps,
+ alpha=args.opt.alpha,
+ )
+ else:
+ raise NotImplementedError
+
+ lr_scheduler = None
+ if args.opt.lr_decay:
+ # decay learning rate to 0 linearly
+ max_update_num = np.ceil(args.train.step_per_epoch / args.train.step_per_collect) * args.train.epoch
+ lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+
+
+ # RL agent
+ dist = CategoricalMasked
+ if args.train.algo == 'PPO':
+ policy = MaskedPPOPolicy(
+ actor=actor,
+ critic=critic,
+ optim=optim,
+ dist_fn=dist,
+ discount_factor=args.train.gamma,
+ eps_clip=args.train.clip_param,
+ advantage_normalization=False,
+ vf_coef=args.loss.value,
+ ent_coef=args.loss.entropy,
+ gae_lambda=args.train.gae_lambda,
+ lr_scheduler=lr_scheduler
+ )
+ elif args.algo == 'A2C':
+ policy = MaskedA2CPolicy(
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.train.gamma,
+ vf_coef=args.loss.value,
+ ent_coef=args.loss.entropy,
+ gae_lambda=args.train.gae_lambda,
+ lr_scheduler=lr_scheduler
+ )
+ else:
+ raise NotImplementedError
+
+ log_path = './logs/' + time_str
+
+ is_debug = True if sys.gettrace() else False
+ if not is_debug:
+ writer = SummaryWriter(log_path)
+ logger = TensorboardLogger(
+ writer=writer,
+ train_interval=args.log_interval,
+ update_interval=args.log_interval
+ )
+ # backup the config file, os.path.join(,)
+ shutil.copy(args.config, log_path) # config file
+ shutil.copy("model.py", log_path) # network
+ shutil.copy("arguments.py", log_path) # network
+ else:
+ logger = LazyLogger()
+
+ # ======== callback functions used during training =========
+ def train_fn(epoch, env_step):
+ # monitor leraning rate in tensorboard
+ # writer.add_scalar('train/lr', optim.param_groups[0]["lr"], env_step)
+ pass
+
+ def save_best_fn(policy):
+ if not is_debug:
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy_step_best.pth'))
+ else:
+ pass
+
+ def final_save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy_step_final.pth'))
+
+ def save_checkpoint_fn(epoch, env_step, gradient_step):
+ if not is_debug:
+ # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ ckpt_path = os.path.join(log_path, "checkpoint.pth")
+ # Example: saving by epoch num
+ # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
+ torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, ckpt_path)
+ return ckpt_path
+ else:
+ return None
+
+ def watch(train_info):
+ print("Setup test envs ...")
+ policy.eval()
+ test_envs.seed(args.seed)
+ print("Testing agent ...")
+ test_collector.reset()
+ result = test_collector.collect(n_episode=1000)
+ ratio = result["ratio"]
+ ratio_std = result["ratio_std"]
+ total = result["num"]
+ print(f"The result (over {result['n/ep']} episodes): ratio={ratio}, ratio_std={ratio_std}, total={total}")
+ with open(os.path.join(log_path, f"{ratio:.4f}_{ratio_std:.4f}_{total}.txt"), "w") as file:
+ file.write(str(train_info).replace("{", "").replace("}", "").replace(", ", "\n"))
+
+ buffer = VectorReplayBuffer(total_size=10000, buffer_num=len(train_envs))
+ train_collector = PackCollector(policy, train_envs, buffer)
+ test_collector = PackCollector(policy, test_envs)
+
+ # trainer
+ result = onpolicy_trainer(
+ policy,
+ train_collector,
+ test_collector,
+ max_epoch=args.train.epoch,
+ step_per_epoch=args.train.step_per_epoch,
+ repeat_per_collect=args.train.repeat_per_collect,
+ episode_per_test=10, # args.test_num,
+ batch_size=args.train.batch_size,
+ step_per_collect=args.train.step_per_collect,
+ # episode_per_collect=args.episode_per_collect,
+ train_fn=train_fn,
+ save_best_fn=save_best_fn,
+ save_checkpoint_fn=save_checkpoint_fn,
+ logger=logger,
+ test_in_train=False
+ )
+
+ final_save_fn(policy)
+ pprint.pprint(f'Finished training! \n{result}')
+ watch(result)
+
+
+if __name__ == '__main__':
+ registration_envs()
+ args = arguments.get_args()
+ args.train.algo = args.train.algo.upper()
+ args.train.step_per_collect = args.train.num_processes * args.train.num_steps
+
+ train(args)