Proposed synchronisation of RL and LM approaches #4

Open · wants to merge 24 commits into main
Conversation

@thomfoster (Collaborator) commented Jan 29, 2023

A reinforcement learning algorithm is characterised by the trajectories it generates during training. We are interested in "algorithm distillation": whether these trajectories can be modelled by transformers, as studied in the original DeepMind algorithm distillation paper.

My particular interest in this field is the case where:

  1. the trajectories have been generated by the TRLx library during RLHF training of language models
  2. the transformer modelling the trajectories is itself a standard language model

The current repo doesn't yet account for this case:

  1. the trajectories are generated for traditional RL environments using OpenAI Gym. In the current repo, these are collected online during AD training, which is infeasible for TRLx.
  2. the transformer modelling the trajectories is a GPT-2 model with state, action and reward heads attached.

This pull request is designed to synchronise these approaches, allowing exploration of both RL and LM tasks, and distillation of trajectories into both RL and LM transformers.

There's still a bit more to do, but I wanted to open this PR to show the direction I'm working in.

More detail on data formats

A trajectory is typically defined as a list of (state, action, reward) triples. For training purposes, it is sometimes useful to augment this with logprobs: for each triple $(s, a, r)$, the log-probability of taking action $a$ in state $s$ as determined by the policy generating the trajectory.

We therefore define an RL Format Trajectory as a sequence of (state, action, reward, logprobs) tuples.
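
As a concrete illustration, an RL Format Trajectory could be represented with a structure like the following. This is a minimal sketch; the class and field names are hypothetical, not the repo's actual types.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class Step:
    state: np.ndarray     # environment observation s
    action: np.ndarray    # action a taken by the policy
    reward: float         # scalar reward r
    logprobs: np.ndarray  # log-probability of a at s under the generating policy


# An RL Format Trajectory is then just an ordered sequence of such steps.
RLTrajectory = List[Step]
```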

The typical way to learn to model these trajectories with a transformer is to separately map the final hidden state through dedicated heads. That is, for a given tuple $(s, a, r, l)$, a transformer $f$ maps to $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$.

In this repo, this is done via the models in /models/rl.
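
As a rough sketch of the idea (not the actual implementation in /models/rl: the class name, embedding scheme and head set are illustrative, and it assumes continuous state/action vectors with scalar reward and logprob):

```python
import torch
import torch.nn as nn
from transformers import GPT2Model


class RLFormatTransformer(nn.Module):
    """GPT-2 backbone with separate prediction heads over the final hidden state."""

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.backbone = GPT2Model.from_pretrained("gpt2")
        hidden = self.backbone.config.n_embd  # 768 for "gpt2"
        # Embed each (s, a, r, l) tuple (reward and logprob treated as scalars here)
        self.embed = nn.Linear(state_dim + action_dim + 2, hidden)
        self.state_head = nn.Linear(hidden, state_dim)
        self.action_head = nn.Linear(hidden, action_dim)
        self.reward_head = nn.Linear(hidden, 1)

    def forward(self, tuples: torch.Tensor) -> dict:
        # tuples: (batch, seq_len, state_dim + action_dim + 2)
        h = self.backbone(inputs_embeds=self.embed(tuples)).last_hidden_state
        return {
            "state": self.state_head(h),
            "action": self.action_head(h),
            "reward": self.reward_head(h),
        }
```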

We are also interested in the ability of standard language models (with language modeling heads) to learn trajectories. To this end we define a Language Format Trajectory as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string:

prompt: Dan went down to the shops.
completion: He smiled as he walked - the sun was shining.
reward: 0.9975
###

It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation.

Trajectories in Language Format are learnt by the models in /models/lm.
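
For the sentiment example above, the serialisation could be as simple as the following sketch (the template, function name and float formatting are placeholders; the optimal format is still an open question):

```python
def to_language_format(prompt: str, completion: str, reward: float) -> str:
    """Serialise one (prompt, completion, reward) rollout into the string template above."""
    return (
        f"prompt: {prompt}\n"
        f"completion: {completion}\n"
        f"reward: {reward:.4f}\n"
        "###\n"
    )


# e.g. to_language_format("Dan went down to the shops.",
#                         "He smiled as he walked - the sun was shining.", 0.9975)
```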

To summarise:

/models contains the "algorithm distillation models": transformers trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on RL Format trajectories and those that operate on Language Format trajectories.

/tasks contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a torch.utils.data.Dataset that can return trajectory data in either RL Format or Language Format.
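
A minimal sketch of that Dataset contract, assuming trajectories are stored as dicts of per-timestep arrays (the class name and `fmt` switch are hypothetical, not the repo's actual API):

```python
import torch
from torch.utils.data import Dataset


class TrajectoryDataset(Dataset):
    """Returns trajectories in either RL Format or Language Format."""

    def __init__(self, trajectories, fmt: str = "rl"):
        # `trajectories`: list of dicts with "states", "actions", "rewards", "logprobs"
        assert fmt in ("rl", "language")
        self.trajectories = trajectories
        self.fmt = fmt

    def __len__(self) -> int:
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        if self.fmt == "rl":
            # RL Format: per-timestep tensors for each field
            return {key: torch.as_tensor(val) for key, val in traj.items()}
        # Language Format: one serialised string per trajectory
        return "\n".join(f"{key}: {val}" for key, val in traj.items()) + "\n###"
```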

Generating trajectory data

I am using my own fork of TRLx that has rollout logging.
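
Purely for illustration, consuming those logged rollouts might look like this, assuming a hypothetical JSON-lines log format with prompt/completion/reward fields (the fork's actual log format may differ):

```python
import json
from pathlib import Path


def load_rollouts(log_dir: str) -> list:
    """Collect logged rollouts from every *.jsonl file under `log_dir`."""
    rollouts = []
    for path in sorted(Path(log_dir).glob("*.jsonl")):
        with path.open() as f:
            rollouts.extend(json.loads(line) for line in f if line.strip())
    return rollouts
```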

ToDo:

- [x] Set up repo structure (just for your language stuff, @h can add in his)
- [x] Add train script for models/lm/causallm
- [x] Clone H's work and merge with @h (/models/rl) and (/tasks/rl)
- [ ] Write a train script that demonstrates how to use with env tasks
- [ ] Switch to using official branch of TRLx (get rollout logging PR approved)
- [ ] Add online evaluation script for models/lm/causallm
- [ ] Improve train script logging to include reward accuracy

Potential future tasks:

- [ ] Add more elegant metaclass switching between ...LanguageTrajectories and ...RlTrajectories
- [ ] Add main file with click CLI interface for running experiments

@thomfoster changed the title from "Synchronise RL and LM approaches" to "Proposed synchronisation of RL and LM approaches" on Jan 29, 2023