Commit

Merge pull request #86 from kaist-silab/refac
`v0.1.0`
fedebotu authored Jul 22, 2023
2 parents 5de9a04 + 6080c05 commit a54e687
Showing 150 changed files with 4,113 additions and 2,813 deletions.
12 changes: 12 additions & 0 deletions .github/codecov.yaml
@@ -0,0 +1,12 @@
coverage:
  status:
    project:
      default:
        # set target (e.g. 60%) to fail the build if coverage is too low
        target: 60%
    patch:
      default:
        # basic just to show current patch
        target: auto
        threshold: 0%
        base: auto
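
As a rough illustration of what `target` and `threshold` express, the sketch below models a coverage status check in plain Python. It is not Codecov's implementation: the function name `status_passes` and the simplified rule (pass when coverage is at least the required value minus the threshold, with `auto` taken to mean the base commit's coverage) are assumptions made for illustration.

```python
from typing import Optional, Union


def status_passes(
    coverage: float,
    target: Union[float, str],
    threshold: float = 0.0,
    base_coverage: Optional[float] = None,
) -> bool:
    """Hypothetical, simplified model of a coverage status check.

    All values are percentages. `target="auto"` is assumed to mean
    "compare against the coverage of the base commit".
    """
    if target == "auto":
        if base_coverage is None:
            raise ValueError("target='auto' needs the base commit's coverage")
        required = base_coverage
    else:
        required = float(target)
    # The threshold is the drop below the required value that is still tolerated.
    return coverage >= required - threshold


# Project status: fixed 60% target.
project_ok = status_passes(72.5, target=60.0)
# Patch status: `target: auto` with `threshold: 0%` tolerates no drop at all.
patch_ok = status_passes(80.0, target="auto", threshold=0.0, base_coverage=81.0)
print(project_ok, patch_ok)
```

Under these assumptions, a project at 72.5% passes the 60% target, while a patch one point below its base fails because the threshold is zero.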
8 changes: 7 additions & 1 deletion .github/workflows/tests.yml
@@ -35,4 +35,10 @@ jobs:
pip install -e ".[testing]"
- name: Run pytest
run: pytest tests/test_*
run: pytest --cov=rl4co tests/*.py

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

50 changes: 20 additions & 30 deletions README.md
@@ -1,12 +1,14 @@
<div align="center">

![rl4co_titlebar_withlogo](https://github.com/kaist-silab/rl4co/assets/34462374/58e087eb-8791-4e92-a9da-fe0f680a11e4)
<img src="https://github.com/kaist-silab/rl4co/assets/34462374/249462ea-b15d-4358-8a11-6508903dae58" style="width:40%">
<br><br>


<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://github.com/pytorch/rl"><img alt="base: TorchRL" src="https://img.shields.io/badge/base-TorchRL-red"></a>
<a href="https://hydra.cc/"><img alt="config: Hydra" src="https://img.shields.io/badge/config-Hydra-89b8cd"></a> [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)[![Slack](https://img.shields.io/badge/slack-chat-611f69.svg?logo=slack)](https://join.slack.com/t/rl4co/shared_invite/zt-1ytz2c1v4-0IkQ8NQH4TRXIX8PrRmDhQ)
![license](https://img.shields.io/badge/license-Apache%202.0-green.svg?)<a href="https://colab.research.google.com/github/kaist-silab/rl4co/blob/main/notebooks/1-quickstart.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>[![PyPI](https://img.shields.io/pypi/v/rl4co?logo=pypi)](https://pypi.org/project/rl4co)
<a href="https://hydra.cc/"><img alt="config: Hydra" src="https://img.shields.io/badge/config-Hydra-89b8cd"></a> [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Slack](https://img.shields.io/badge/slack-chat-611f69.svg?logo=slack)](https://join.slack.com/t/rl4co/shared_invite/zt-1ytz2c1v4-0IkQ8NQH4TRXIX8PrRmDhQ)
![license](https://img.shields.io/badge/license-Apache%202.0-green.svg?) <a href="https://colab.research.google.com/github/kaist-silab/rl4co/blob/main/notebooks/1-quickstart.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> [![PyPI](https://img.shields.io/pypi/v/rl4co?logo=pypi)](https://pypi.org/project/rl4co)
[![Test](https://github.com/kaist-silab/rl4co/actions/workflows/tests.yml/badge.svg)](https://github.com/kaist-silab/rl4co/actions/workflows/tests.yml)
<!-- ![testing](https://github.com/kaist-silab/ncobench/actions/workflows/tests.yml/badge.svg) -->

@@ -26,8 +28,7 @@ RL4CO is built upon:
- [PyTorch Lightning](https://github.com/Lightning-AI/lightning): a lightweight PyTorch wrapper for high-performance AI research
- [Hydra](https://github.com/facebookresearch/hydra): a framework for elegantly configuring complex applications

![image](https://github.com/kaist-silab/rl4co/assets/48984123/0db4efdd-1c93-4991-8f09-f3c6c1f35d60)

![RL4CO Overview](https://github.com/kaist-silab/rl4co/assets/34462374/4d9a670f-ab7c-4fc8-9135-82d17cb6d0ee)

## Getting started
<a href="https://colab.research.google.com/github/kaist-silab/rl4co/blob/main/notebooks/1-quickstart.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
@@ -105,41 +106,30 @@ python run.py -m experiment=tsp/am train.optimizer.lr=1e-3,1e-4,1e-5

### Minimalistic Example

Here is a minimalistic example training the Attention Model with greedy rollout baseline on TSP in less than 50 lines of code:
Here is a minimalistic example training the Attention Model with greedy rollout baseline on TSP in less than 30 lines of code:

```python
from omegaconf import DictConfig
import lightning as L
from rl4co.envs import TSPEnv
from rl4co.models.zoo.am import AttentionModel
from rl4co.tasks.rl4co import RL4COLitModule

config = DictConfig(
{"data": {
"train_size": 100000,
"val_size": 10000,
"batch_size": 512,
},
"optimizer": {"lr": 1e-4}}
)
from rl4co.models import AttentionModel
from rl4co.utils import RL4COTrainer

# Environment, Model, and Lightning Module
env = TSPEnv(num_loc=20)
model = AttentionModel(env)
lit_module = RL4COLitModule(config, env, model)
model = AttentionModel(env,
baseline="rollout",
train_data_size=100_000,
test_data_size=10_000,
optimizer_kwargs={'lr': 1e-4}
)

# Trainer
trainer = L.Trainer(
max_epochs=3, # only few epochs
accelerator="gpu", # use GPU if available, else you can use others as "cpu"
logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
precision="16-mixed", # Lightning will handle faster training with mixed precision
gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
reload_dataloaders_every_n_epochs=1, # necessary for sampling new data
)
trainer = RL4COTrainer(max_epochs=3)

# Fit the model
trainer.fit(lit_module)
trainer.fit(model)

# Test the model
trainer.test(model)
```


3 changes: 3 additions & 0 deletions configs/experiment/README.md
@@ -0,0 +1,3 @@
# Refactored Experiments

We have substantially refactored RL4CO, and the older experiment configurations will be updated to the new, more efficient standards. To reproduce the experiments exactly as in our preprint, refer to the [older experiments](archive/README.md).
9 changes: 9 additions & 0 deletions configs/experiment/archive/README.md
@@ -0,0 +1,9 @@
# Older experiment versions

These are the experiments we ran for the first version of our paper. Starting with version `0.1.0`, we added several new features and carried out a major refactoring that simplifies the codebase, so these configurations target the earlier releases.

We will update the experiments to the refactored versions. To run the archived ones as they are, install RL4CO at version `0.0.6` or earlier:

```bash
pip install "rl4co<=0.0.6"
```
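
A quick way to check whether an environment satisfies this constraint is sketched below using only the standard library. The helper names (`parse_release`, `supports_archived_experiments`, `installed_rl4co_version`) are hypothetical and not part of RL4CO; the parser only handles plain numeric releases such as `0.0.6` and ignores any suffix.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_release(version_string: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version string, e.g. '0.0.6' -> (0, 0, 6)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"not a numeric version: {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def supports_archived_experiments(version_string: str) -> bool:
    """True if the version is 0.0.6 or earlier, i.e. before the refactoring."""
    return parse_release(version_string) <= (0, 0, 6)


def installed_rl4co_version() -> Optional[str]:
    """Return the installed rl4co version, or None if it is not installed."""
    try:
        return version("rl4co")
    except PackageNotFoundError:
        return None


current = installed_rl4co_version()
if current is None:
    print("rl4co is not installed")
elif supports_archived_experiments(current):
    print(f"rl4co {current} can run the archived experiments")
else:
    print(f"rl4co {current} is newer than 0.0.6; use the refactored configs")
```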
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -43,7 +43,7 @@ data:
batch_size: 512
train_size: 1_280_000
val_size: 10_000

model:
num_starts: 0 # 0 for no augmentation for multi-starts
num_augment: 10
@@ -22,7 +22,7 @@ logger:
tags: ${tags}
group: "dpp"
name: "am"

seed: 12345

env:
@@ -50,4 +50,3 @@ model:
_target_: rl4co.models.rl.reinforce.critic.CriticNetwork
env: ${env}
use_native_sdpa: ${model.use_native_sdpa}

File renamed without changes.
File renamed without changes.
@@ -22,7 +22,7 @@ logger:
tags: ${tags}
group: "mtsp${env.num_loc}"
name: "am"

seed: 12345

env:
@@ -29,7 +29,7 @@ env:
num_loc: 20
min_num_agents: 5
max_num_agents: 5

trainer:
min_epochs: 10
max_epochs: 100
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -43,7 +43,7 @@ data:
batch_size: 512
train_size: 1_280_000
val_size: 10_000

model:
num_starts: 0 # 0 for no augmentation for multi-starts
num_augment: 10
File renamed without changes.
@@ -22,7 +22,7 @@ logger:
tags: ${tags}
group: "tsp${env.num_loc}"
name: "ham"

seed: 12345

env:
@@ -33,7 +33,7 @@ trainer:
max_epochs: 100
gradient_clip_val: 1.0
accelerator: "gpu"

train:
optimizer:
# _partial_: True
File renamed without changes.
@@ -29,7 +29,7 @@ logger:
hydra:
run:
dir: ${paths.log_dir}/${mode}/runs/${logger.wandb.group}/${logger.wandb.name}/${now:%Y-%m-%d}_${now:%H-%M-%S}

seed: 12345

env:
File renamed without changes.
@@ -8,7 +8,7 @@ defaults:
# - override /logger: null # comment this line to enable logging
- override /logger: wandb.yaml

transfer: # transfer to
transfer: # transfer to
source:
problem: 'cvrp'
size: 50
@@ -8,7 +8,7 @@ defaults:
# - override /logger: null # comment this line to enable logging
- override /logger: wandb.yaml

transfer: # transfer to
transfer: # transfer to
source:
problem: 'cvrp'
size: 50
@@ -57,7 +57,7 @@ data:
batch_size: 512
train_size: 1_280_000
val_size: 10_000

model:
num_starts: ${transfer.target.size} # 0 for no augmentation for multi-starts
num_augment: 0
@@ -8,7 +8,7 @@ defaults:
# - override /logger: null # comment this line to enable logging
- override /logger: wandb.yaml

transfer: # transfer to
transfer: # transfer to
source:
problem: 'tsp'
size: 50
@@ -57,7 +57,7 @@ data:
batch_size: 512
train_size: 1_280_000
val_size: 10_000

model:
num_starts: 0 # 0 for no augmentation for multi-starts
num_augment: 10
File renamed without changes.
File renamed without changes.
46 changes: 46 additions & 0 deletions configs/experiment/archive/tsp/am.yaml
@@ -0,0 +1,46 @@
# @package _global_

defaults:
  - override /model: am.yaml
  - override /env: tsp.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  # - override /logger: null # comment this line to enable logging
  - override /logger: wandb.yaml

env:
  num_loc: 50

tags: ["am", "tsp"]

logger:
  wandb:
    project: "rl4co"
    tags: ${tags}
    group: "tsp${env.num_loc}"
    name: "am-tsp${env.num_loc}"

seed: 12345

trainer:
  max_epochs: 100
  gradient_clip_val: 1.0
  accelerator: "gpu"
  precision: "16-mixed"

train:
  optimizer:
    _target_: torch.optim.Adam
    lr: 1e-4
    weight_decay: 0
  scheduler:
    _target_: torch.optim.lr_scheduler.MultiStepLR
    milestones: [80, 95]
    gamma: 0.1
  scheduler_interval: epoch

data:
  batch_size: 512
  train_size: 1_280_000
  val_size: 10_000
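
To make the scheduler settings above concrete, here is a small pure-Python sketch of the learning rate a multi-step schedule yields per epoch, using the config's values (`lr: 1e-4`, `milestones: [80, 95]`, `gamma: 0.1`). The helper `multistep_lr` is hypothetical; it applies the closed-form rule `base_lr * gamma ** (number of milestones reached)` and assumes the scheduler is stepped once per epoch, as `scheduler_interval: epoch` suggests.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(
    base_lr: float, milestones: Sequence[int], gamma: float, epoch: int
) -> float:
    """Learning rate at `epoch` for a multi-step decay schedule."""
    # Count how many milestones are less than or equal to the current epoch.
    reached = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma**reached


# Values taken from the config above.
for epoch in (0, 79, 80, 94, 95, 99):
    lr = multistep_lr(1e-4, [80, 95], 0.1, epoch)
    print(f"epoch {epoch:3d}: lr = {lr:.0e}")
```

With these settings the rate stays at 1e-4 for the first 80 epochs, drops to 1e-5 at epoch 80, and to 1e-6 at epoch 95.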

@@ -22,7 +22,7 @@ logger:
tags: ${tags}
group: "tsp${env.num_loc}"
name: "mdam"

seed: 12345

env:
File renamed without changes.
File renamed without changes.
@@ -43,7 +43,7 @@ data:
batch_size: 512
train_size: 1_280_000
val_size: 10_000

model:
num_starts: 0 # 0 for no augmentation for multi-starts
num_augment: 10
47 changes: 47 additions & 0 deletions configs/experiment/base.yaml
@@ -0,0 +1,47 @@
# @package _global_
# Example configuration for experimenting. Trains the Attention Model on
# the TSP environment with 50 locations via REINFORCE with greedy rollout baseline.
# You may find comments on the most common hyperparameters below.

# Override defaults: take configs from relative path
defaults:
  - override /model: am.yaml
  - override /env: tsp.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  # - override /logger: null # comment this line to enable logging
  - override /logger: wandb.yaml

# Environment configuration
# Note that here we load by default the `.npz` files for the TSP environment
# that are automatically generated with seed following Kool et al. (2019).
env:
  num_loc: 50
  data_dir: ${paths.root_dir}/data/tsp
  val_file: tsp${env.num_loc}_val_seed4321.npz
  test_file: tsp${env.num_loc}_test_seed1234.npz

# Logging: we use Wandb in this case
logger:
  wandb:
    project: "rl4co"
    tags: ["am", "tsp"]
    group: "tsp${env.num_loc}"
    name: "am-tsp${env.num_loc}"

# Model: this contains the environment (which gets automatically passed to the model on
# initialization), the policy network and other hyperparameters.
# This is a `LightningModule` and can be trained with PyTorch Lightning.
model:
  batch_size: 512
  train_data_size: 1_280_000
  val_data_size: 10_000
  test_data_size: 10_000
  optimizer_kwargs:
    lr: 1e-4

# Trainer: this is a customized version of the PyTorch Lightning trainer.
trainer:
  max_epochs: 100

seed: 1234
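
The `${env.num_loc}` references in this config are interpolations that get replaced by the referenced value when the config is resolved. The sketch below imitates that substitution with a small regex-based resolver so the resulting file names are easy to see. It is an illustration only, not how Hydra or OmegaConf implement interpolation; `lookup` and `resolve` are hypothetical helpers, and the dictionary copies just the few keys needed from the config above.

```python
import re
from typing import Any, Dict

# Matches "${some.dotted.key}" and captures the dotted key.
_INTERPOLATION = re.compile(r"\$\{([^}]+)\}")


def lookup(config: Dict[str, Any], dotted_key: str) -> Any:
    """Follow a dotted key such as 'env.num_loc' through nested dicts."""
    node: Any = config
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(template: str, config: Dict[str, Any]) -> str:
    """Replace every ${dotted.key} in `template` with its value from `config`."""
    return _INTERPOLATION.sub(lambda m: str(lookup(config, m.group(1))), template)


config = {
    "env": {"num_loc": 50},
    "model": {"batch_size": 512, "train_data_size": 1_280_000},
}

val_file = resolve("tsp${env.num_loc}_val_seed4321.npz", config)
run_name = resolve("am-tsp${env.num_loc}", config)
# Full batches per epoch implied by the model settings.
batches_per_epoch = config["model"]["train_data_size"] // config["model"]["batch_size"]

print(val_file)           # tsp50_val_seed4321.npz
print(run_name)           # am-tsp50
print(batches_per_epoch)  # 2500
```

Changing `num_loc` in one place therefore updates the validation file, the test file, and the logger group and run name together.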