
Official code for Optimal Transport Reward labeling (OTR)


Note

This repository is published for archival purposes and captures the version of the software used for reproducing the results. The author does not intend to continue maintenance or development in this repository. If you are interested in using OTR alongside other RL algorithms in JAX, check out https://github.com/ethanluoyc/corax, which includes an implementation of OTR as well as many other RL algorithms.

This repository includes the official JAX implementation of Optimal Transport for Offline Imitation Learning by Yicheng Luo, Zhengyao Jiang, Samuel Cohen, Edward Grefenstette, Marc Peter Deisenroth.

If you find this repository useful, please cite our paper:

@inproceedings{
luo2023otr,
title={Optimal Transport for Offline Imitation Learning},
author={Yicheng Luo and Zhengyao Jiang and Samuel Cohen and Edward Grefenstette and Marc Peter Deisenroth},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=MhuFzFsrfvH}
}

Updates

  • 2024/04/10: The code has been updated to work with Python 3.10 and JAX 0.4.19. In addition, the dependency on acme has been replaced with corax. Check out the Releases tab for the version of the repository that was used for the original paper submission.

Installation

  1. Follow the instructions to install MuJoCo 2.1.0.

  2. Create a Python virtual environment by running

python -m venv venv
source venv/bin/activate

The code is tested with Python 3.10.

  3. Install the dependencies by running
# Installing the runtime dependencies (pinned)
pip install -r requirements.txt

The requirements files are generated from the requirements/*.in files with uv pip compile to ensure reproducible dependency resolution. If you do not need pinned versions, the requirements/*.in files list the direct dependencies required to run the project.

Running the experiments

To reproduce the results in the paper, you can run

# --workdir: directory to save the logs
# --config: an ml_collections configuration file
# --config.expert_dataset_name: D4RL dataset from which to retrieve the expert dataset
# --config.k: number of expert episodes to use from the expert dataset
# --config.offline_dataset_name: D4RL dataset from which to retrieve the unlabeled dataset
# --config.use_dataset_reward: whether to use the reward from the original dataset instead of OTR-computed rewards
python -m otr.train_offline \
    --workdir=/tmp/otr \
    --config=otr/configs/otil_iql_mujoco.py \
    --config.expert_dataset_name='hopper-medium-replay-v2' \
    --config.k=10 \
    --config.offline_dataset_name='hopper-medium-replay-v2' \
    --config.use_dataset_reward=True

Please refer to the configuration files in otr/configs for more configuration options that you can override.

Repository Overview

The reference OTR implementation is located in otr/agents/otil/rewarder.py. Under the hood, it uses OTT-JAX to solve the optimal transport problem and transforms the optimal transport solution into rewards that can be used by an offline RL agent.
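
As an illustration of that idea, the sketch below shows how an entropic OT coupling computed with OTT-JAX can be turned into per-state rewards. This is a minimal example, not the code in rewarder.py: the function name ot_rewards, the default squared-Euclidean point-cloud cost, the epsilon value, and the absence of any reward scaling or squashing are simplifying assumptions.

# A minimal, illustrative sketch: label one unlabeled episode against one
# expert episode using entropic OT. This is not the exact logic in
# otr/agents/otil/rewarder.py; the cost function and scaling are assumptions.
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def ot_rewards(unlabeled_states, expert_states, epsilon=0.01):
    # Point-cloud geometry between the two sets of states; the default
    # cost is the squared Euclidean distance.
    geom = pointcloud.PointCloud(unlabeled_states, expert_states, epsilon=epsilon)
    # Solve the entropic OT problem between the two empirical distributions.
    out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
    # out.matrix is the optimal coupling; charge each unlabeled state the
    # transport cost it is responsible for and negate it, so states closer
    # to the expert trajectory receive higher reward.
    per_state_cost = jnp.sum(out.matrix * geom.cost_matrix, axis=1)
    return -per_state_cost

The actual rewarder may additionally scale or squash these values before they are consumed by the offline RL agent; refer to otr/agents/otil/rewarder.py and the configuration files for the exact behaviour.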

Licenses and Acknowledgements

The code is licensed under the MIT license. The IQL implementation is based on https://github.com/ikostrikov/implicit_q_learning which is under the MIT license.