Note
This repository is published for archival purposes and captures the version of the software used to reproduce the results in the paper. The author does not intend to continue maintenance or development in this repository. If you are interested in using OTR alongside other RL algorithms in JAX, you may want to check out https://github.com/ethanluoyc/corax, which includes an implementation of OTR as well as many other RL algorithms.
This repository includes the official JAX implementation of *Optimal Transport for Offline Imitation Learning* (ICLR 2023) by Yicheng Luo, Zhengyao Jiang, Samuel Cohen, Edward Grefenstette, and Marc Peter Deisenroth.

If you find this repository useful, please cite our paper:
    @inproceedings{
    luo2023otr,
    title={Optimal Transport for Offline Imitation Learning},
    author={Yicheng Luo and Zhengyao Jiang and Samuel Cohen and Edward Grefenstette and Marc Peter Deisenroth},
    booktitle={The Eleventh International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=MhuFzFsrfvH}
    }
- 2024/04/10: The code has been updated to work with Python 3.10 and JAX 0.4.19. In addition, the dependency on acme has been replaced with corax. Check out the Releases tab for the version of the repository that was used for the original paper submission.
- Follow the instructions to install MuJoCo 2.1.0.
- Create a Python virtual environment by running

      python -m venv venv
      source venv/bin/activate

  The code is tested with Python 3.10.
- Install the dependencies by running

      # Install the pinned runtime dependencies
      pip install -r requirements.txt

  The requirements files are generated from the `requirements/*.in` files with `uv pip compile` to ensure reproducible dependency resolution. If exact pins are not needed, the direct dependencies for running the project can be found in the `requirements/*.in` files.
To reproduce the results in the paper, you can run

    python -m otr.train_offline \
      --workdir=/tmp/otr \
      --config=otr/configs/otil_iql_mujoco.py \
      --config.expert_dataset_name='hopper-medium-replay-v2' \
      --config.k=10 \
      --config.offline_dataset_name='hopper-medium-replay-v2' \
      --config.use_dataset_reward=True

where
- `--workdir` is the directory in which to save the logs,
- `--config` is an ml_collections configuration file,
- `--config.expert_dataset_name` is the D4RL dataset from which the expert demonstrations are retrieved,
- `--config.k` is the number of expert episodes to use from the expert dataset,
- `--config.offline_dataset_name` is the D4RL dataset used as the unlabeled dataset, and
- `--config.use_dataset_reward`, if True, uses the reward from the original dataset instead of the rewards computed by OTR.
Please refer to the configuration files in `otr/configs` for more configuration options that you can override.
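For orientation, below is a minimal sketch of what such a configuration module might look like, assuming the usual ml_collections convention of exposing a `get_config()` function. The field names are taken from the command above, but the default values and overall structure are assumptions, not the actual contents of `otr/configs/otil_iql_mujoco.py`.

```python
# Hypothetical sketch of a config module in the style of otr/configs.
# The default values below are placeholders, not the repository's defaults.
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()
  # D4RL dataset providing the expert demonstrations.
  config.expert_dataset_name = 'hopper-medium-replay-v2'
  # Number of expert episodes to draw from the expert dataset.
  config.k = 10
  # D4RL dataset used as the unlabeled dataset to be relabeled by OTR.
  config.offline_dataset_name = 'hopper-medium-replay-v2'
  # If True, keep the reward from the original dataset instead of the OTR reward.
  config.use_dataset_reward = False
  return config
```

Every field of the returned `ConfigDict` can then be overridden from the command line as `--config.<field>=<value>`, as in the example above.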
The reference OTR implementation is located in `otr/agents/otil/rewarder.py`. Under the hood, it uses OTT-JAX to solve the optimal transport problem and transforms the optimal transport solution into rewards that can be used by an offline RL agent.
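To illustrate the mechanism, here is a minimal sketch (not the reference implementation): solve an entropy-regularized optimal transport problem between the states of an unlabeled episode and an expert episode with OTT-JAX, then credit each unlabeled state with the negative transport cost it incurs under the resulting coupling. The function name, the cosine cost, the epsilon value, and the omission of any reward scaling are assumptions here, and the exact OTT-JAX import paths depend on the pinned version.

```python
# Illustrative sketch of an OT-based reward; not the code in otr/agents/otil/rewarder.py.
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn


def ot_rewards(agent_states: jnp.ndarray,
               expert_states: jnp.ndarray,
               epsilon: float = 0.01) -> jnp.ndarray:
  """Assigns each agent state the negative transport cost it pays towards the expert states."""
  # Geometry between the two point clouds of states; cosine cost is one plausible choice.
  geom = pointcloud.PointCloud(
      agent_states, expert_states, cost_fn=costs.Cosine(), epsilon=epsilon)
  # Solve the entropy-regularized OT problem with the Sinkhorn solver.
  out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
  coupling = out.matrix         # optimal coupling, shape [T_agent, T_expert]
  cost_matrix = geom.cost_matrix
  # Pseudo-reward per agent state: negative cost mass transported from that state.
  return -jnp.sum(coupling * cost_matrix, axis=1)
```

In the actual pipeline, each episode of the unlabeled dataset is relabeled against the expert episodes before offline RL (IQL) training, and the paper additionally applies a scaling to the raw optimal transport rewards, so the sketch above should be read as capturing the core idea only.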
The code is licensed under the MIT license. The IQL implementation is based on https://github.com/ikostrikov/implicit_q_learning, which is also released under the MIT license.