Proposed synchronisation of RL and LM approaches #4
Open
thomfoster wants to merge 24 commits into CarperAI:main from thomfoster:main
Conversation
…sentiment_rollouts.py to reflect that this is the script to generate data, not the class that uses it
… just roc stories
…oaches Task/synchronize rl and lm approaches
thomfoster changed the title from "Synchronise RL and LM approaches" to "Proposed synchronisation of RL and LM approaches" on Jan 29, 2023
A reinforcement learning algorithm is characterised by the trajectories it generates during training. We are interested in "algorithm distillation" - whether trajectories can be modelled by transformers, as studied in the original DeepMind algorithm distillation paper.
My particular interest in this field is the case where:
The current repo doesn't account for this, with:
This pull request is designed to synchronise these approaches to allow the exploration of both RL and LM tasks, and for the distillation of trajectories into both RL and LM transformers.
There's still a bit more to do, but I wanted to open this PR to show the direction I was working in.
More detail on data formats
A trajectory is typically defined as a list of `(state, action, reward)` triples. For training purposes, it is sometimes useful to augment this to include `logprobs`, which is, for each triple `(s, a, r)`, the probability of taking action $a$ at state $s$ as determined by the policy generating the trajectory. We therefore define an RL Format Trajectory as a sequence of `(state, action, reward, logprobs)` tuples.

The typical way to learn to model these trajectories with a transformer is to map the final hidden state through separate heads, one per component: for a given tuple `(s, a, r, l)`, a transformer $f$ maps to $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$. In this repo, this is done via the models in `/models/rl`.
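As a rough illustration of the idea, here is a minimal PyTorch sketch; the class name, dimensions and backbone choice are assumptions for illustration, not the actual code in `/models/rl`:

```python
import torch
import torch.nn as nn

class RLFormatTrajectoryModel(nn.Module):
    """Sketch: map the final hidden state of each (s, a, r, l) tuple to
    (s_hat, a_hat, r_hat, l_hat) predictions via separate linear heads."""

    def __init__(self, d_model: int = 128, state_dim: int = 8, n_actions: int = 4):
        super().__init__()
        # Any causal transformer backbone could sit here; a small encoder
        # stands in to keep the sketch self-contained.
        self.backbone = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True),
            num_layers=2,
        )
        self.state_head = nn.Linear(d_model, state_dim)    # -> s_hat
        self.action_head = nn.Linear(d_model, n_actions)   # -> a_hat (logits)
        self.reward_head = nn.Linear(d_model, 1)           # -> r_hat
        self.logprob_head = nn.Linear(d_model, n_actions)  # -> l_hat

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model) embeddings of (state, action, reward, logprobs) tuples
        h = self.backbone(x)
        return self.state_head(h), self.action_head(h), self.reward_head(h), self.logprob_head(h)
```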
We are also interested in the ability of standard language models (with language modeling heads) to learn trajectories. To this end we define a Language Format Trajectory as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as a string along the lines of the sketch below.
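The snippet below is only one possible serialisation; the `prompt:`/`completion:`/`reward:` labels and the step separator are assumptions made for illustration, not necessarily the format used in this PR:

```python
# Sketch: flatten a sentiment fine-tuning trajectory of (prompt, completion, reward)
# steps into a single string that a standard language model can be trained on.
def serialise_step(prompt: str, completion: str, reward: float) -> str:
    return f"prompt: {prompt} completion: {completion} reward: {reward:.2f}"

def serialise_trajectory(steps) -> str:
    # steps: iterable of (prompt, completion, reward) triples in rollout order
    return " <|step|> ".join(serialise_step(*step) for step in steps)

print(serialise_trajectory([
    ("The movie was", " absolutely wonderful", 0.92),
    ("I left the cinema feeling", " thrilled", 0.97),
]))
# prompt: The movie was completion:  absolutely wonderful reward: 0.92 <|step|> prompt: ...
```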
It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation.
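For instance, a numeric environment state could be turned into text by printing rounded coordinates; the labels and precision below are arbitrary illustrative choices:

```python
# Sketch: serialise a numeric environment state (e.g. lander position/velocity)
# by enumerating its coordinates as rounded decimals.
def serialise_env_step(state, action: int, reward: float) -> str:
    coords = " ".join(f"{x:.2f}" for x in state)
    return f"state: {coords} action: {action} reward: {reward:.2f}"

print(serialise_env_step([0.31, 1.20, -0.05, 0.00], action=2, reward=-0.1))
# state: 0.31 1.20 -0.05 0.00 action: 2 reward: -0.10
```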
Trajectories in Language Format are learnt by the models in `/models/lm`.

To summarise:

- `/models` contains the "algorithm distillation models", transformers that are trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on RL Format trajectories and models that operate on Language Format trajectories.
- `/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either RL Format or Language Format (see the interface sketch below).
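A minimal sketch of the kind of interface such a dataset might expose, assuming a `fmt` flag to switch between the two representations; the class, argument and field names are illustrative, not the actual code in `/tasks`:

```python
import torch
from torch.utils.data import Dataset

class TrajectoryDataset(Dataset):
    """Sketch: serve stored rollouts either as RL Format tensors or as
    Language Format strings."""

    def __init__(self, trajectories, fmt: str = "rl"):
        # trajectories: list of dicts with "states", "actions", "rewards", "logprobs"
        assert fmt in ("rl", "lm")
        self.trajectories = trajectories
        self.fmt = fmt

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        if self.fmt == "rl":
            # RL Format: (state, action, reward, logprobs) tensors for the trajectory
            return (
                torch.tensor(traj["states"], dtype=torch.float32),
                torch.tensor(traj["actions"], dtype=torch.long),
                torch.tensor(traj["rewards"], dtype=torch.float32),
                torch.tensor(traj["logprobs"], dtype=torch.float32),
            )
        # Language Format: the whole trajectory flattened into one string
        return " <|step|> ".join(
            f"state: {s} action: {a} reward: {r:.2f}"
            for s, a, r in zip(traj["states"], traj["actions"], traj["rewards"])
        )
```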
Generating trajectory data

I am using my own fork of TRLx that has rollout logging.
ToDo:
Still to do:
[X] Set up repo structure (just for your language stuff, @h can add in his)
[X] Add train script for models/lm/casuallm
[X] Clone H's work and merge with @h (/models/rl) and (/tasks/rl)
[ ] Write a train script that demonstrates how to use with env tasks
[ ] Switch to using official branch of TRLx (get rollout logging PR approved)
[ ] Add online evaluation script for models/lm/casuallm
[ ] Improve train script logging to include reward accuracy
Potential future tasks:
[ ] Add more elegant meta class switching between ...LanguageTrajectories and ...RlTrajectories
[ ] Add main file with click CLI interface for running experiments