This repo contains a unified open-source implementation of Offline RL algorithms, developed further on the basis of d3rlpy (https://github.com/takuseno/d3rlpy). The repo will be updated regularly to include new research (development is currently in progress).
Standard reinforcement learning (RL) learns how to perform a task through trial and error, balancing exploration and exploitation to achieve better performance. Offline Reinforcement Learning (Offline RL), also known as Batch Reinforcement Learning (BRL), is a variant of RL that requires an agent to learn to perform tasks from a fixed dataset without exploration. In other words, Offline RL is a data-driven RL paradigm concerned with learning exclusively from static datasets of previously collected experiences. In their review paper "Offline reinforcement learning: Tutorial, review, and perspectives on open problems", Sergey Levine et al. illustrate the relationship and differences between Offline RL and standard RL with a figure.
Following the classification by Aviral Kumar and Sergey Levine in their NeurIPS 2020 tutorial, we divide existing Offline RL algorithms into four categories:
- Policy Constraint Methods (PC)
- Value Function Regularization Methods (VR)
- Model-based Methods (MB)
- Uncertainty-based Methods (U)
In addition, we include a further class of algorithms that move from offline to online learning:
- Offline to Online (Off2On)
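To make two of these categories more concrete, the NumPy sketch below writes out one well-known regularizer from each, in simplified form: a TD3+BC-style actor loss (Fujimoto & Gu, 2021) as a policy constraint, and a CQL(H)-style penalty for discrete actions (Kumar et al., 2020) as value function regularization. These are illustrations only, not the implementations in this repo; the function names and the default alpha=2.5 (the value suggested in the TD3+BC paper) are our assumptions.

```python
import numpy as np


def td3_bc_actor_loss(q_pi, pi_actions, data_actions, alpha=2.5):
    """Policy constraint: maximize Q while staying close to the dataset actions.

    q_pi: (batch,) values Q(s, pi(s)); pi_actions, data_actions: (batch, act_dim).
    """
    # lambda rescales the Q term by its average magnitude (treated as a constant)
    lam = alpha / np.abs(q_pi).mean()
    # behavior cloning term: mean squared distance to the logged actions
    bc_term = np.mean((pi_actions - data_actions) ** 2)
    return float(-lam * q_pi.mean() + bc_term)


def cql_penalty(q_values, data_actions):
    """Value regularization: log-sum-exp of Q over actions minus Q of the data action.

    q_values: (batch, n_actions); data_actions: (batch,) integer actions.
    Minimizing it pushes Q down on all actions and back up on dataset actions.
    """
    q_max = q_values.max(axis=1, keepdims=True)  # for a numerically stable log-sum-exp
    logsumexp = q_max[:, 0] + np.log(np.exp(q_values - q_max).sum(axis=1))
    q_data = q_values[np.arange(len(data_actions)), data_actions]
    return float(np.mean(logsumexp - q_data))


q = np.array([[1.0, 5.0], [2.0, 2.0]])
print(cql_penalty(q, np.array([0, 0])), cql_penalty(q, np.array([1, 0])))
```

The CQL-style penalty is never negative, and it is larger in the first call, where the dataset action in the first row has a much lower Q-value than the alternative, than in the second, where the dataset action is already the preferred one.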
Current deep RL methods still typically rely on active data collection to succeed, which hinders their application in the real world, especially when data collection is dangerous or expensive. Offline RL (also known as batch RL) instead learns exclusively from static datasets of previously collected experiences: a behavior policy interacts with the environment to collect a set of experiences, which can later be used to learn a policy without further interaction. This paradigm can be extremely valuable in settings where online interaction is impractical. However, current offline RL methods face three challenges:
- Low performance ceiling: the quality of the offline data bounds the performance of offline RL algorithms. How can low-quality offline data be expanded, without additional interaction, to raise this ceiling?
- Poor algorithm performance: existing off-policy/offline algorithms are trained on the offline data distribution. When the learned policy interacts with the environment, the distribution of visited state-action pairs may differ from that of the offline data (distributional shift). In this situation, the Q-values of <state, action> pairs are prone to overestimation, which hurts overall performance. How can data outside the offline data distribution (out-of-distribution, OOD) be characterized to avoid overestimation?
- Difficulty in applying the algorithm: because the quality of the dataset is limited, the learned policy cannot be deployed directly in a production environment, and further online learning is required. How can data sampling be designed in the online training phase to avoid a sudden drop in the policy's initial performance, caused by redundant data generated by the distribution shift, and to converge quickly to the optimal solution within a limited number of interactions?
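To make the overestimation problem concrete, the toy sketch below runs tabular Q-iteration on a fixed three-transition dataset; the dataset, the arbitrary initial value of 5.0 and all names are invented for illustration. Because the behavior policy never tried action 2, those Q-values are never corrected, and the unconstrained max in the Bellman target both inflates the value of a logged action and makes the greedy policy choose the unseen one. Restricting the max to actions that appear in the data, a crude tabular form of a policy constraint in the spirit of batch-constrained Q-learning, avoids this.

```python
import numpy as np

# Toy fixed dataset of (state, action, reward, next_state, done) tuples.
# The behavior policy that collected it never tried action 2.
DATASET = [
    (0, 0, 0.0, 1, False),
    (1, 1, 1.0, 1, True),
    (0, 1, 0.2, 0, True),
]
N_STATES, N_ACTIONS, GAMMA = 2, 3, 0.9


def actions_in_data(state: int) -> list:
    """Actions that the dataset contains for a given state."""
    return sorted({a for s, a, *_ in DATASET if s == state})


def fitted_q(constrained: bool, n_iters: int = 50, init: float = 5.0) -> np.ndarray:
    """Tabular Q-iteration that only ever reads DATASET (no environment access)."""
    q = np.full((N_STATES, N_ACTIONS), init)  # arbitrary initial estimates
    for _ in range(n_iters):
        for s, a, r, s_next, done in DATASET:
            if done:
                q[s, a] = r
                continue
            # unconstrained: max over all actions, including ones absent from the data;
            # constrained: max only over the actions the data contains for s_next
            candidates = actions_in_data(s_next) if constrained else list(range(N_ACTIONS))
            q[s, a] = r + GAMMA * q[s_next, candidates].max()
    return q


def greedy_action(q: np.ndarray, state: int, constrained: bool) -> int:
    candidates = actions_in_data(state) if constrained else list(range(N_ACTIONS))
    return candidates[int(np.argmax(q[state, candidates]))]


q_naive = fitted_q(constrained=False)
q_constrained = fitted_q(constrained=True)
print(q_naive[0], greedy_action(q_naive, 0, constrained=False))
print(q_constrained[0], greedy_action(q_constrained, 0, constrained=True))
```

In this run the unconstrained estimate of Q(0, 0) stays at 4.5 although the data only supports a return of 0.9, and the greedy policy picks action 2, which the data never contains; restricting the max to logged actions recovers 0.9 and picks action 0.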
This repository contains the code of representative benchmarks and algorithms on the topic of Offline Reinforcement Learning. It is developed on the basis of d3rlpy, following the MIT license, to shed light on research into the three challenges above. While inheriting d3rlpy's advantages, it adds (or will add) the following features:
- For people who are interested in Offline RL, our introductions to each algorithm and our tutorial blogs can be helpful.
- For RL practitioners (especially those working in related fields), we provide advanced Offline RL algorithms with strong performance and different kinds of datasets. In detail, we provide content and support for:
All algorithms in this repo are implemented in Python 3.7 (and above). PyTorch is the main deep learning framework adopted in this repo, with different choices for different algorithms.
First of all, we recommend installing Anaconda or venv for convenient management of different Python environments.
This repo needs the following data and environments:
- OpenAI Gym (e.g., MuJoCo, Robotics)
- D4RL
- Waymo Datasets
Note that each algorithm may use only one or several of the environments listed above. Please refer to the page of each algorithm for its concrete requirements.
To clone this repo:
```bash
git clone git@github.com:TJU-DRL-LAB/offline-rl-algorithms.git
```
Here we introduce how to configure your own dataset and modify an algorithm according to your own design.
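- Modify the get_dataset function in d3rlpy (imported from d3rlpy.datasets) so that it dispatches to a new get_your_data function:
```python
from typing import Tuple

import gym
from d3rlpy.dataset import MDPDataset


def get_dataset(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    if env_name == "existing datasets":
        return get_existing_datasets()  # the branches d3rlpy already provides
    elif env_name == "your own datasets":
        # forward the mask options, since get_your_data passes them to MDPDataset
        return get_your_data(create_mask=create_mask, mask_size=mask_size)
    raise ValueError(f"Unrecognized env_name: {env_name}.")
```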
- Load your dataset and transform it into the MDPDataset format:
```python
import numpy as np
from d3rlpy.dataset import MDPDataset


def get_your_data(create_mask: bool = False, mask_size: int = 1):
    # load_your_raw_data and make_your_env are hypothetical helpers to replace:
    # the raw data must be a dict of NumPy arrays with the D4RL-style keys
    # "observations", "actions", "rewards", "terminals" and "timeouts"
    dataset = load_your_raw_data()
    env = make_your_env()

    observations = []
    actions = []
    rewards = []
    terminals = []
    episode_terminals = []
    episode_step = 0
    cursor = 0
    dataset_size = dataset["observations"].shape[0]
    while cursor < dataset_size:
        # collect data for step=t; the reward stored with observation t is
        # the reward received on the transition that led to it
        observation = dataset["observations"][cursor]
        action = dataset["actions"][cursor]
        if episode_step == 0:
            reward = 0.0
        else:
            reward = dataset["rewards"][cursor - 1]
        observations.append(observation)
        actions.append(action)
        rewards.append(reward)
        terminals.append(0.0)
        # on a timeout, close the episode without adding a terminal step
        if dataset["timeouts"][cursor]:
            episode_terminals.append(1.0)
            episode_step = 0
            cursor += 1
            continue
        episode_terminals.append(0.0)
        episode_step += 1
        if dataset["terminals"][cursor]:
            # collect data for step=t+1
            dummy_observation = observation.copy()
            dummy_action = action.copy()
            next_reward = dataset["rewards"][cursor]
            # the last observation is rarely used
            observations.append(dummy_observation)
            actions.append(dummy_action)
            rewards.append(next_reward)
            terminals.append(1.0)
            episode_terminals.append(1.0)
            episode_step = 0
        cursor += 1

    mdp_dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=np.array(actions, dtype=np.float32),
        rewards=np.array(rewards, dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
        create_mask=create_mask,
        mask_size=mask_size,
    )
    return mdp_dataset, env
```
- Get your own dataset with:
```python
import argparse

from d3rlpy.datasets import get_dataset

parser = argparse.ArgumentParser()
# the default must match the name handled in get_dataset
parser.add_argument("--dataset", type=str, default="your own datasets")
args = parser.parse_args()
dataset, env = get_dataset(args.dataset)
```
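The get_your_data loader above reads a dict of flat NumPy arrays in the D4RL layout, where per-step terminals and timeouts flags mark episode ends. If your data is stored episode by episode, a helper like the hypothetical episodes_to_flat_dict below can build that dict. The companion check_raw_dataset flags two cases that the loader's logic would silently absorb: a step marked as both timeout and terminal (the timeout branch runs first, so the terminal is dropped) and a last step with neither flag (the final episode never receives an episode terminal). All names, and the per-episode layout with a boolean "terminal" field, are our assumptions.

```python
from typing import Dict, List

import numpy as np

REQUIRED_KEYS = ("observations", "actions", "rewards", "terminals", "timeouts")


def episodes_to_flat_dict(episodes: List[dict]) -> Dict[str, np.ndarray]:
    """Concatenate non-empty episodes into one flat dict of per-step arrays.

    Each episode holds "observations" (T, obs_dim), "actions" (T, act_dim),
    "rewards" (T,) and a bool "terminal": True if it ended in a true terminal
    state, False if it was cut off by a time limit.
    """
    flat = {key: [] for key in REQUIRED_KEYS}
    for ep in episodes:
        length = len(ep["rewards"])
        terminals = np.zeros(length, dtype=bool)
        timeouts = np.zeros(length, dtype=bool)
        # flag only the last step, either as a terminal or as a timeout
        if ep["terminal"]:
            terminals[-1] = True
        else:
            timeouts[-1] = True
        flat["observations"].append(np.asarray(ep["observations"], dtype=np.float32))
        flat["actions"].append(np.asarray(ep["actions"], dtype=np.float32))
        flat["rewards"].append(np.asarray(ep["rewards"], dtype=np.float32))
        flat["terminals"].append(terminals)
        flat["timeouts"].append(timeouts)
    return {key: np.concatenate(parts) for key, parts in flat.items()}


def check_raw_dataset(dataset: Dict[str, np.ndarray]) -> List[str]:
    """Return human-readable problems found in a flat raw dataset dict."""
    missing = [key for key in REQUIRED_KEYS if key not in dataset]
    if missing:
        return [f"missing keys: {missing}"]
    sizes = {key: len(dataset[key]) for key in REQUIRED_KEYS}
    if len(set(sizes.values())) != 1:
        return [f"inconsistent lengths: {sizes}"]
    if sizes["observations"] == 0:
        return ["dataset is empty"]
    problems = []
    terminals = np.asarray(dataset["terminals"], dtype=bool)
    timeouts = np.asarray(dataset["timeouts"], dtype=bool)
    # the loader tests timeouts first, so these steps lose their terminal flag
    both = np.flatnonzero(terminals & timeouts)
    if both.size:
        problems.append(f"steps marked as both terminal and timeout: {both.tolist()}")
    # otherwise the loader never closes the final episode
    if not (terminals[-1] or timeouts[-1]):
        problems.append("last step is neither terminal nor timeout")
    return problems


ep1 = {"observations": np.zeros((3, 2)), "actions": np.zeros((3, 1)),
       "rewards": np.array([0.0, 0.0, 1.0]), "terminal": True}
ep2 = {"observations": np.ones((2, 2)), "actions": np.ones((2, 1)),
       "rewards": np.array([0.5, 0.5]), "terminal": False}
raw = episodes_to_flat_dict([ep1, ep2])
print(raw["terminals"].astype(int), raw["timeouts"].astype(int))
print(check_raw_dataset(raw))
```

Here the first episode ends in a terminal at step 2 and the second is cut off by a timeout at step 4, so the check reports no problems; the resulting dict is what load_your_raw_data would be expected to return.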
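Assuming you are modifying an algorithm based on SAC:
- Create two Python files named YourSAC.py and YourSACImpl.py. The YourSACImpl class defined in YourSACImpl.py inherits from SACImpl.
```python
from d3rlpy.algos.torch.sac_impl import SACImpl

class YourSACImpl(SACImpl):
    def __init__(self, a, b, **kwargs):
        super().__init__(**kwargs)  # the standard SACImpl arguments
        self._a, self._b = a, b  # your own (placeholder) hyperparameters
```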
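- Modify your algorithm in YourSACImpl.py by overriding compute_critic_loss, compute_actor_loss, or other methods.
```python
import torch
from d3rlpy.algos.torch.sac_impl import SACImpl
from d3rlpy.torch_utility import TorchMiniBatch


class YourSACImpl(SACImpl):  # these methods belong in YourSACImpl.py
    def compute_critic_loss(self, batch: TorchMiniBatch, q_tpn: torch.Tensor) -> torch.Tensor:
        observations = batch.observations
        actions = batch.actions
        rewards = batch.next_rewards
        ...  # your own computation, e.g. using the target values q_tpn
        # critic_loss_func is a placeholder for your own critic loss
        your_critic_loss = critic_loss_func(observations, actions, rewards)
        return your_critic_loss

    def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
        observations = batch.observations
        actions = batch.actions
        ...  # your own computation
        # actor_loss_func is a placeholder for your own actor loss
        your_actor_loss = actor_loss_func(observations, actions)
        return your_actor_loss
```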
- Import YourSACImpl in YourSAC.py and modify the _create_impl method so that it passes your algorithm's parameters to YourSACImpl:
```python
from typing import Sequence

def _create_impl(self, observation_shape: Sequence[int], action_size: int) -> None:
    # A, B: your own hyperparameters; also pass observation_shape, action_size, etc.
    self._impl = YourSACImpl(a=A, b=B)
    self._impl.build()
```
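When filling in critic_loss_func and actor_loss_func, it can help to keep the standard SAC objectives in view. The NumPy sketch below writes them out outside d3rlpy: a soft Bellman target using the minimum of two target critics, a mean-squared TD error, and the entropy-regularized actor loss. The function names, the default gamma and alpha, and the two-critic setup are assumptions for illustration. In d3rlpy's SAC the next-step part of this target is, as far as we understand, what arrives as q_tpn, so check how q_tpn is computed in your d3rlpy version before duplicating it.

```python
import numpy as np


def sac_critic_target(rewards, terminals, next_q1, next_q2, next_log_probs,
                      gamma=0.99, alpha=0.2):
    """Soft target r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s')).

    next_q1 and next_q2 are target-critic values at (s', a') with a' ~ pi(.|s').
    """
    next_value = np.minimum(next_q1, next_q2) - alpha * next_log_probs
    return rewards + gamma * (1.0 - terminals) * next_value


def critic_mse_loss(q_pred, target):
    """Mean squared TD error (in PyTorch, compute the target under torch.no_grad())."""
    return float(np.mean((q_pred - target) ** 2))


def sac_actor_loss(log_probs, q1_pi, q2_pi, alpha=0.2):
    """Mean of alpha * log pi(a|s) - min(Q1, Q2)(s, a), with a ~ pi(.|s)."""
    return float(np.mean(alpha * log_probs - np.minimum(q1_pi, q2_pi)))


target = sac_critic_target(
    rewards=np.array([1.0, 0.0]), terminals=np.array([0.0, 1.0]),
    next_q1=np.array([2.0, 5.0]), next_q2=np.array([3.0, 4.0]),
    next_log_probs=np.array([-1.0, -1.0]), gamma=0.5, alpha=0.1,
)
print(target)
```

Here the first sample bootstraps to 1 + 0.5 * (2 + 0.1) = 2.05, while the terminal second sample keeps only its reward, 0.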
- Update the license
- Update the README file of each branch
- Check the validity of the code before release
If you use our repo in your work, we ask that you cite our paper.
Here is an example BibTeX:
```bibtex
@article{aaa22xxxx,
  author = {tjurllab},
  title = {A Unified Repo for Offline RL},
  year = {2022},
  url = {http://arxiv.org/abs/xxxxxxx},
  archivePrefix = {arXiv}
}
```
[To change]
[To add some acknowledgement]