SFT for D2L + Pre-Training (rename of the previous SFT) #102
base: main
Conversation
execute with "python -m example.rlhf.supervised_finetuning_d2l"
Temporarily use all entries in the dataset as the training set (i.e., no eval)
…ers into a csv file
…ction, whether to disable evaluation configurable
… Use trl DataCollatorForCompletionOnlyLM instead of the customized one. Debug: cannot use ConstantLengthDataset or packing when using DataCollatorForCompletionOnlyLM
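
For context, a minimal sketch (not the PR's actual code) of wiring trl's DataCollatorForCompletionOnlyLM into SFTTrainer with packing disabled, which is the constraint the last commit describes. The model name, response template, and toy dataset are illustrative assumptions, and this assumes a trl version where SFTTrainer accepts these arguments directly:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

train_dataset = Dataset.from_dict(
    {"text": ["### Question: What is 2+2?\n### Answer: 4"]}
)

# The collator masks everything before the response template, so the loss
# is computed on the answer tokens only.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Answer:", tokenizer=tokenizer
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field="text",
    data_collator=collator,
    packing=False,  # packing/ConstantLengthDataset conflict with this collator
    tokenizer=tokenizer,
)
trainer.train()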
example/autorate/auto-rater.ipynb
Outdated
qq: what is this .ipynb file for?
This is used to generate synthetic immigration data by rephrasing.
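
A hedged sketch of what rephrase-based augmentation like this might look like; the CSV path, column name, and the trivial rephrase stand-in are all assumptions, since the notebook's actual prompt and model are not shown in this thread:

import pandas as pd

def rephrase(text: str) -> str:
    # Placeholder: the real notebook would call an LLM here with a prompt
    # such as "Rephrase the following question: ...".
    return f"Could you tell me: {text}"

df = pd.read_csv("immigration_qa.csv")  # assumed file and schema
augmented = df.assign(question=df["question"].map(rephrase))
pd.concat([df, augmented]).to_csv("immigration_qa_augmented.csv", index=False)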
pykoi/rlhf/config.py
Outdated
@@ -5,6 +5,7 @@
from accelerate import Accelerator
from peft import LoraConfig, TaskType
# TODO: DH: num_train_epochs=20,
nit: what is this commented-out code for?
@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Union
qq: do we still need this customized collator, per our discussion?
pykoi/rlhf/pre_traning.py
Outdated
        seq_length=args.max_seq_length,
        # chars_per_token=chars_per_token,
    )
    return {"train": train_dataset, "eval": eval_dataset}
nit: a newline is needed at the end of the file. Make sure you set up your linter properly, as we discussed.
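
For readers of the snippet above, a sketch of how a packed dataset is typically built with trl's ConstantLengthDataset; everything except the seq_length/chars_per_token arguments is an assumption about the surrounding code:

from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
raw = Dataset.from_dict({"text": ["example document text " * 50]})

train_dataset = ConstantLengthDataset(
    tokenizer,
    raw,
    dataset_text_field="text",
    seq_length=512,         # plays the role of args.max_seq_length
    # chars_per_token=3.6,  # optional estimate, commented out in the PR
)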
pykoi/rlhf/supervised_finetuning.py
Outdated
f" Answer: {example[self._rlhf_config.answer_title]}") | ||
return text | ||
|
||
def prepare_d2l_text(self, example): |
nit: let's rename this method, because it can be used for other things.
Also, please add what you have tested for this PR.
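
As a hedged illustration of the suggested rename, a generically named formatter; the column names and defaults here are hypothetical, not the PR's actual config values:

def prepare_text(example, question_title="question", answer_title="answer"):
    """Format one dataset row into a single prompt/answer training string."""
    return (f"Question: {example[question_title]}"
            f" Answer: {example[answer_title]}")

# Example usage:
row = {"question": "What is D2L?", "answer": "Dive into Deep Learning."}
print(prepare_text(row))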
pykoi/rlhf/supervised_finetuning.py
Outdated
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
                                     QA_CSV_HEADER_QUESTION,
                                     QA_CSV_HEADER_VOTE_STATUS)
from pykoi.chat.db.qa_database import QuestionAnswerDatabase
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry
from trl import DataCollatorForCompletionOnlyLM
# from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM
nit: let's remove unused code.
pykoi/rlhf/supervised_finetuning.py
Outdated
# resize the token embeddings to include the added special tokens
self.model.resize_token_embeddings(len(self.tokenizer))
data_collator = None
if self._rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
You should consider setting data_collator to the DataCollatorForCompletionOnlyLM class instead of a string for the SFT training argument. Then, here you should check for None. Also, it looks like a bug to me if people use SFT without passing in a data_collator; therefore, you should set a proper default value in config.py.
I agree that a class is better than a string in the argument file.
I believe that if it is set to None, the default data collator will be used when None is passed to trl.SFTTrainer. Since the default data collator also depends on other parameters such as packing, setting it to None by default makes more sense than a fixed class.
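
A sketch of how the agreed-upon default could look in config.py; the field name and dataclass layout are assumptions rather than the repo's actual code:

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class RLHFConfigSketch:
    # None means "let trl.SFTTrainer pick its default collator", which also
    # depends on other settings such as packing; a completion-only run would
    # store a DataCollatorForCompletionOnlyLM instance here instead.
    data_collator: Optional[Callable] = None

# In supervised_finetuning.py the check then becomes a None test rather
# than a string comparison:
# if config.data_collator is not None:
#     trainer_kwargs["data_collator"] = config.data_collator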
…r, we have to initialize the SFTTrainer in another way
Implement SFT and use D2L as a demo case. Rename the previous SFT to Pre-training and modify the corresponding scripts/notebooks.