Merge pull request #50 from llauraa23/main
Modify step 3 of RLHF: support using fine-tuned models from steps 1 and 2.
goldmermaid authored Sep 10, 2023
2 parents c93083d + 6896163 commit ef820de
Showing 4 changed files with 142 additions and 14 deletions.
example/rlhf/demo_rl.py: 4 changes (2 additions, 2 deletions)
@@ -13,11 +13,11 @@

# use huggingface sft and reward model
config = pykoi.RLHFConfig(
base_model_path="elinas/llama-7b-hf-transformers-4.29", # "elinas/llama-7b-hf-transformers-4.29",
base_model_path="models/rlhf_step1_sft", #"elinas/llama-7b-hf-transformers-4.29",
dataset_type="huggingface",
dataset_name="goldmermaid/stack_exchange_rank_10k_dataset",
dataset_subset_rl="data",
reward_model_path="cambioml/rlhf-reward-model",
reward_model_path="models/rlhf_step2_rw/", #"cambioml/rlhf_reward_model",
save_freq=1,
ppo_batch_size=32,
ppo_epochs=4,
example/rlhf/demo_rw_finetuning.py: 2 changes (1 addition, 1 deletion)
@@ -26,6 +26,6 @@

# run reward model finetuning
# config = pykoi.RLHFConfig(dataset_type="local_db")
config = pykoi.RLHFConfig()
config = pykoi.RLHFConfig(reward_model_path="databricks/dolly-v2-3b")
rlhf_step2_rft = pykoi.RewardFinetuning(config)
rlhf_step2_rft.train_and_save("./models/rlhf_step2_rw")
example/rlhf/supervised_finetuning_demo.py: 2 changes (1 addition, 1 deletion)
@@ -25,6 +25,6 @@
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
config = pykoi.RLHFConfig(base_model_path="elinas/llama-7b-hf-transformers-4.29", dataset_type="local_db")
config = pykoi.RLHFConfig(base_model_path="databricks/dolly-v2-3b", dataset_type="local_db")
rlhf_step1_sft = pykoi.SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
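
Taken together, the three example scripts above now chain through local checkpoint directories instead of Hub models. Below is a minimal sketch of that end-to-end flow; running all three steps in one process is purely illustrative (each demo is its own script), and only the paths, classes, and config fields visible in the diffs above are used.

import pykoi

# Step 1: supervised finetuning on the local demo database; writes ./models/rlhf_step1_sft
sft_config = pykoi.RLHFConfig(base_model_path="databricks/dolly-v2-3b", dataset_type="local_db")
pykoi.SupervisedFinetuning(sft_config).train_and_save("./models/rlhf_step1_sft")

# Step 2: reward model finetuning; writes ./models/rlhf_step2_rw
rw_config = pykoi.RLHFConfig(reward_model_path="databricks/dolly-v2-3b")
pykoi.RewardFinetuning(rw_config).train_and_save("./models/rlhf_step2_rw")

# Step 3 (this PR): RL finetuning that consumes the step 1 and step 2 outputs
rl_config = pykoi.RLHFConfig(
    base_model_path="models/rlhf_step1_sft",
    reward_model_path="models/rlhf_step2_rw/",
    dataset_type="huggingface",
    dataset_name="goldmermaid/stack_exchange_rank_10k_dataset",
    dataset_subset_rl="data",
)
# rl_config is then passed to the step 3 trainer set up in demo_rl.py (not shown in this diff).
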
pykoi/rlhf/rl_finetuning.py: 148 changes (138 additions, 10 deletions)
@@ -27,6 +27,9 @@
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from peft import PeftModel, PeftConfig, AutoPeftModelForCausalLM
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import (
RLStartEvent,
@@ -76,12 +79,56 @@ def __init__(self,
rlhf_config.reward_model_path
)
self.reward_dataset = self.create_dataset(self.reward_tokenizer)
self.reward_model = AutoModelForSequenceClassification.from_pretrained(
rlhf_config.reward_model_path,
num_labels=1,
load_in_8bit=True,
device_map={"": Accelerator().local_process_index},
# self.reward_model = AutoModelForSequenceClassification.from_pretrained(
# rlhf_config.reward_model_path,
# num_labels=1,
# load_in_8bit=True,
# device_map={"": Accelerator().local_process_index},
# )

reward_model_path = rlhf_config.reward_model_path

try:
# If there is a trained peft adapter in the hub, load its config.
remote_adapter_config_reward = hf_hub_download(reward_model_path, "adapter_config.json")
except Exception:
remote_adapter_config_reward = None


local_adapter_present_reward = os.path.exists(
os.path.join(reward_model_path, "adapter_config.json")
)

# Load the trained peft adapter config
if local_adapter_present_reward:
trained_adapter_config_reward = PeftConfig.from_pretrained(reward_model_path)
else:
trained_adapter_config_reward = PeftConfig.from_pretrained(remote_adapter_config_reward)

## Load the pretrained base model
pretrained_kwargs_reward = {
"num_labels": 1,
"load_in_8bit": False, #True,
"device_map": {"": Accelerator().local_process_index},
} # TODO: ADD
pretrained_model_reward = AutoModelForSequenceClassification.from_pretrained(
trained_adapter_config_reward.base_model_name_or_path,
**pretrained_kwargs_reward
)
## TODO: LOAD MERGED BASE MODEL FROM STEP 2

# Load the Peft model by combining the base model with the trained adapter
reward_model = PeftModel.from_pretrained(pretrained_model_reward, reward_model_path, is_trainable=False) # TODO: fix this. This should not be trainable.
self.reward_model = reward_model.merge_and_unload()
#pretrained_model.print_trainable_parameters()
print("\nTrained peft adapter loaded for reward model\n")
# The pad_token_id must be set, otherwise batched scoring fails with "Cannot handle batch sizes > 1 if no padding token is defined"
# see https://stackoverflow.com/questions/68084302/assertionerror-cannot-handle-batch-sizes-1-if-no-padding-token-is-defined
self.reward_model.config.pad_token_id = self.reward_tokenizer.pad_token_id

self.reward_kwargs = {
"top_k": None,
"function_to_apply": "none",
@@ -102,12 +149,93 @@
## Load the base model and tokenizer and define the PPO Trainer for RL
self.base_tokenizer = self.create_tokenizer(rlhf_config.base_model_path)
self.base_dataset = self.create_dataset(self.base_tokenizer)
self.base_model = AutoModelForCausalLMWithValueHead.from_pretrained(
rlhf_config.base_model_path,
load_in_8bit=rlhf_config.load_in_8bit,
device_map={"": Accelerator().local_process_index},
peft_config=rlhf_config.lora_config_rl,

pretrained_model_name_or_path = rlhf_config.base_model_path
# NOTE/TODO: the peft config is inferred directly from the pre-trained model, so rlhf_config.lora_config_rl (used in the previous implementation) is ignored here. Should it still be used in the flow of taking the merged model as the base model and adding a new peft adapter on top?

pretrained_kwargs = {
"load_in_8bit": rlhf_config.load_in_8bit,
"device_map": {"": Accelerator().local_process_index},
}

assert isinstance(pretrained_model_name_or_path, str), "The `pretrained_model_name_or_path` should be a string."
try:
# If there is a trained peft adapter in the hub, load its config.
remote_adapter_config = hf_hub_download(pretrained_model_name_or_path, "adapter_config.json")
except Exception:
remote_adapter_config = None


local_adapter_present = os.path.exists(
os.path.join(pretrained_model_name_or_path, "adapter_config.json")
)

# Load the trained peft adapter config
if local_adapter_present:
trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
else:
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_config)

# Load the pretrained base model
pretrained_model = AutoModelForCausalLM.from_pretrained(
trained_adapter_config.base_model_name_or_path,
**pretrained_kwargs
)

# Load the Peft model by combining the base model with the trained adapter
is_trainable = True # TODO: if following the merge-then-train-a-new-adapter flow, this should not be trainable!
pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable)

#pretrained_model.print_trainable_parameters()
print("\nTrained peft adapter loaded for policy model\n")

# Alternatively, load a peft model from a local path. See https://huggingface.co/docs/peft/quicktour. # TODO: DELETE. doesn't work
# peft_model = AutoPeftModelForCausalLM.from_pretrained(pretrained_model_name_or_path)


# Add value head to the pretrained peft model to create a policy network.
if isinstance(pretrained_model, PeftModel):
is_peft_model = True
trl_model_args = {} # args for the value head
# TODO: weights of v_head are initialized with v_head_init_strategy="random" by default; trl also supports initialization using "norm".
model = AutoModelForCausalLMWithValueHead(pretrained_model, **trl_model_args)
# TODO: 1. the value head is created with requires_grad=False and not on CUDA; check whether the code below fixes this. 2. PeftModel's print_trainable_parameters() returns ... and None.


# For backward compatibility with AutoModelForCausalLMWithValueHead: is_peft_model must be set, or calling model.state_dict() will fail.
model.is_peft_model = is_peft_model
# For backward compatibility
model.is_sequential_parallel = True
model.current_device = Accelerator().local_process_index
reward_adapter = None # TODO: Consider adding reward adapter here?
if is_peft_model and reward_adapter is not None:
model.add_and_load_reward_modeling_adapter(reward_adapter)
model.supports_rm_adapter = True
else:
model.supports_rm_adapter = False


# Move v_head to the model's device and register a hook; see AutoModelForCausalLMWithValueHead.post_init().
# TODO: is register_forward_hook necessary? Outputs should already be on CUDA.
first_device = list(set(model.pretrained_model.hf_device_map.values()))[0]
model.v_head = model.v_head.to(first_device)
def set_device_hook(module, input, outputs):
new_output = ()
for output in outputs:
if isinstance(output, torch.Tensor):
new_output += (output.to(first_device),)
else:
new_output += (output,)
return new_output
model.register_forward_hook(set_device_hook)
self.base_model = model
#breakpoint()
# self.base_model = AutoModelForCausalLMWithValueHead.from_pretrained(
# rlhf_config.base_model_path,
# load_in_8bit=rlhf_config.load_in_8bit,
# device_map={"": Accelerator().local_process_index},
# peft_config=rlhf_config.lora_config_rl,
# )
self.ppo_trainer = PPOTrainer(
config=self.ppo_config,
model=self.base_model,
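
For orientation, the reward-model branch of the new __init__ above reduces to the peft load-and-merge pattern sketched here. This is a minimal sketch rather than the committed implementation: it assumes the step 2 output directory contains an adapter_config.json and a saved tokenizer, and it omits the local-versus-Hub probing, 8-bit loading, and device_map handling shown in the diff.

from peft import PeftConfig, PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Step 2 output: a LoRA adapter directory (or Hub repo) containing adapter_config.json.
reward_model_path = "models/rlhf_step2_rw/"

# The adapter config records which base model the adapter was trained on top of.
adapter_config = PeftConfig.from_pretrained(reward_model_path)

# Load that base model as a single-label sequence classifier (the reward scorer).
base_model = AutoModelForSequenceClassification.from_pretrained(
    adapter_config.base_model_name_or_path, num_labels=1
)

# Attach the trained adapter, then merge it into the base weights so the result
# behaves like a plain transformers model when PPO queries it for rewards.
reward_model = PeftModel.from_pretrained(base_model, reward_model_path, is_trainable=False)
reward_model = reward_model.merge_and_unload()

# Batched scoring needs a pad token id (assumes step 2 saved a tokenizer alongside the adapter).
tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
reward_model.config.pad_token_id = tokenizer.pad_token_id
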
