From 2f42914a0507647d49bba5882edead54b3a63930 Mon Sep 17 00:00:00 2001
From: llauraa23
Date: Sun, 10 Sep 2023 12:17:58 -0700
Subject: [PATCH 1/2] In step 3 of RLHF, support using the policy model trained
 in step 1 and the reward model trained in step 2. Due to limited memory with
 single-GPU training, use dolly-v2-3b as the base model in steps 1 and 2.

---
 example/rlhf/demo_rl.py                    |   4 +-
 example/rlhf/demo_rw_finetuning.py         |   1 +
 example/rlhf/supervised_finetuning_demo.py |   1 +
 pykoi/rlhf/rl_finetuning.py                | 148 +++++++++++++++++++--
 4 files changed, 142 insertions(+), 12 deletions(-)

diff --git a/example/rlhf/demo_rl.py b/example/rlhf/demo_rl.py
index ac5fc36..48a5b8b 100644
--- a/example/rlhf/demo_rl.py
+++ b/example/rlhf/demo_rl.py
@@ -5,11 +5,11 @@
 
 # use huggingface sft and reward model
 config = pykoi.RLHFConfig(
-    base_model_path="elinas/llama-7b-hf-transformers-4.29",  # "elinas/llama-7b-hf-transformers-4.29",
+    base_model_path="models/rlhf_step1_sft",  # "elinas/llama-7b-hf-transformers-4.29",
     dataset_type="huggingface",
     dataset_name="goldmermaid/stack_exchange_rank_10k_dataset",
     dataset_subset_rl="data",
-    reward_model_path="goldmermaid/rlhf_reward_model",
+    reward_model_path="models/rlhf_step2_rw/",  # "cambioml/rlhf_reward_model",
     save_freq=1,
     ppo_batch_size=32,
     ppo_epochs=4,
diff --git a/example/rlhf/demo_rw_finetuning.py b/example/rlhf/demo_rw_finetuning.py
index e8422f5..8a4fcb1 100644
--- a/example/rlhf/demo_rw_finetuning.py
+++ b/example/rlhf/demo_rw_finetuning.py
@@ -24,5 +24,6 @@
 # run reward model finetuning
 # config = pykoi.RLHFConfig(dataset_type="local_db")
 config = pykoi.RLHFConfig()
+config.base_model_path = "databricks/dolly-v2-3b"
 rlhf_step2_rft = pykoi.RewardFinetuning(config)
 rlhf_step2_rft.train_and_save("./models/rlhf_step2_rw")
diff --git a/example/rlhf/supervised_finetuning_demo.py b/example/rlhf/supervised_finetuning_demo.py
index df92619..629b479 100644
--- a/example/rlhf/supervised_finetuning_demo.py
+++ b/example/rlhf/supervised_finetuning_demo.py
@@ -23,5 +23,6 @@
 
 # run supervised finetuning
 config = pykoi.RLHFConfig(base_model_path="elinas/llama-7b-hf-transformers-4.29", dataset_type="local_db")
+config.base_model_path = "databricks/dolly-v2-3b"
 rlhf_step1_sft = pykoi.SupervisedFinetuning(config)
 rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
diff --git a/pykoi/rlhf/rl_finetuning.py b/pykoi/rlhf/rl_finetuning.py
index d60eb56..741946b 100644
--- a/pykoi/rlhf/rl_finetuning.py
+++ b/pykoi/rlhf/rl_finetuning.py
@@ -25,6 +25,9 @@
 )
 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 from trl.core import LengthSampler
+from huggingface_hub import hf_hub_download
+from transformers import AutoModelForCausalLM
+from peft import PeftModel, PeftConfig, AutoPeftModelForCausalLM
 
 
 class RLFinetuning(Trainer):
@@ -66,12 +69,56 @@ def __init__(self, rlhf_config: RLHFConfig):
             rlhf_config.reward_model_path
         )
         self.reward_dataset = self.create_dataset(self.reward_tokenizer)
-        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
-            rlhf_config.reward_model_path,
-            num_labels=1,
-            load_in_8bit=True,
-            device_map={"": Accelerator().local_process_index},
+        # self.reward_model = AutoModelForSequenceClassification.from_pretrained(
+        #     rlhf_config.reward_model_path,
+        #     num_labels=1,
+        #     load_in_8bit=True,
+        #     device_map={"": Accelerator().local_process_index},
+        # )
+
+        reward_model_path = rlhf_config.reward_model_path
+
+        try:
+            # If there is a trained peft adapter on the Hub, load its config.
+            remote_adapter_config_reward = hf_hub_download(reward_model_path, "adapter_config.json")
+        except Exception:
+            remote_adapter_config_reward = None
+
+        local_adapter_present_reward = os.path.exists(
+            os.path.join(reward_model_path, "adapter_config.json")
+        )
+
+        # Load the trained peft adapter config
+        if local_adapter_present_reward:
+            trained_adapter_config_reward = PeftConfig.from_pretrained(reward_model_path)
+        else:
+            trained_adapter_config_reward = PeftConfig.from_pretrained(remote_adapter_config_reward)
+
+        ## Load the pretrained base model
+        pretrained_kwargs_reward = {
+            "num_labels": 1,
+            "load_in_8bit": False,  # True,
+            "device_map": {"": Accelerator().local_process_index},
+        }  # TODO: ADD
+        pretrained_model_reward = AutoModelForSequenceClassification.from_pretrained(
+            trained_adapter_config_reward.base_model_name_or_path,
+            **pretrained_kwargs_reward
         )
+        ## TODO: LOAD MERGED BASE MODEL FROM STEP 2
+
+        # Load the Peft model by combining the base model with the trained adapter
+        reward_model = PeftModel.from_pretrained(pretrained_model_reward, reward_model_path, is_trainable=False)  # the reward model should not be trainable
+        self.reward_model = reward_model.merge_and_unload()
+        # pretrained_model.print_trainable_parameters()
+        print("\nTrained peft adapter loaded for reward model\n")
+        # Have to specify the pad_token_id, or it leads to the error "Cannot handle batch sizes > 1 if no padding token is defined".
+        # See https://stackoverflow.com/questions/68084302/assertionerror-cannot-handle-batch-sizes-1-if-no-padding-token-is-defined
+        self.reward_model.config.pad_token_id = self.reward_tokenizer.pad_token_id
+
         self.reward_kwargs = {
             "top_k": None,
             "function_to_apply": "none",
@@ -92,12 +139,93 @@ def __init__(self, rlhf_config: RLHFConfig):
         ## Load the base model and tokenizer and define the PPO Trainer for RL
         self.base_tokenizer = self.create_tokenizer(rlhf_config.base_model_path)
         self.base_dataset = self.create_dataset(self.base_tokenizer)
-        self.base_model = AutoModelForCausalLMWithValueHead.from_pretrained(
-            rlhf_config.base_model_path,
-            load_in_8bit=rlhf_config.load_in_8bit,
-            device_map={"": Accelerator().local_process_index},
-            peft_config=rlhf_config.lora_config_rl,
+
+        pretrained_model_name_or_path = rlhf_config.base_model_path
+        # NOTE/TODO: the peft config is inferred directly from the pretrained model, so rlhf_config.lora_config_rl was ignored in the previous implementation. Should it be used in the flow where the merged model is the base model and a new peft adapter is added on top?
+
+        pretrained_kwargs = {
+            "load_in_8bit": rlhf_config.load_in_8bit,
+            "device_map": {"": Accelerator().local_process_index},
+        }
+
+        assert isinstance(pretrained_model_name_or_path, str), "The `pretrained_model_path` should be a string."
+        try:
+            # If there is a trained peft adapter on the Hub, load its config.
+            remote_adapter_config = hf_hub_download(pretrained_model_name_or_path, "adapter_config.json")
+        except Exception:
+            remote_adapter_config = None
+
+        local_adapter_present = os.path.exists(
+            os.path.join(pretrained_model_name_or_path, "adapter_config.json")
+        )
+
+        # Load the trained peft adapter config
+        if local_adapter_present:
+            trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
+        else:
+            trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_config)
+
+        # Load the pretrained base model
+        pretrained_model = AutoModelForCausalLM.from_pretrained(
+            trained_adapter_config.base_model_name_or_path,
+            **pretrained_kwargs
         )
+
+        # Load the Peft model by combining the base model with the trained adapter
+        is_trainable = True  # TODO: if following the merge-then-train-a-new-adapter flow, this should not be trainable
+        pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable)
+
+        # pretrained_model.print_trainable_parameters()
+        print("\nTrained peft adapter loaded for policy model\n")
+
+        # Alternatively, load a peft model from a local path. See https://huggingface.co/docs/peft/quicktour. # TODO: DELETE, doesn't work
+        # peft_model = AutoPeftModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
+
+        # Add a value head to the pretrained peft model to create a policy network.
+        if isinstance(pretrained_model, PeftModel):
+            is_peft_model = True
+        else:
+            is_peft_model = False
+        trl_model_args = {}  # args for the value head
+        # TODO: weights of v_head are initialized with v_head_init_strategy="random" by default; trl also supports initialization using "norm".
+        model = AutoModelForCausalLMWithValueHead(pretrained_model, **trl_model_args)
+        # TODO: 1. the value head has requires_grad=False and is not on CUDA; check whether the code below fixes this. 2. PeftModel print_trainable_parameters returns ... and None
+
+        # For backward compatibility with AutoModelForCausalLMWithValueHead: is_peft_model needs to be set, or calling model.state_dict() will fail.
+        model.is_peft_model = is_peft_model
+        # For backward compatibility
+        model.is_sequential_parallel = True
+        model.current_device = Accelerator().local_process_index
+        reward_adapter = None  # TODO: consider adding a reward adapter here?
+        if is_peft_model and reward_adapter is not None:
+            model.add_and_load_reward_modeling_adapter(reward_adapter)
+            model.supports_rm_adapter = True
+        else:
+            model.supports_rm_adapter = False
+
+        # Move v_head to the model's device and register a forward hook. See AutoModelForCausalLMWithValueHead.post_init().
+        # TODO: is register_forward_hook necessary? Outputs should already be on CUDA.
+        first_device = list(set(model.pretrained_model.hf_device_map.values()))[0]
+        model.v_head = model.v_head.to(first_device)
+        def set_device_hook(module, input, outputs):
+            new_output = ()
+            for output in outputs:
+                if isinstance(output, torch.Tensor):
+                    new_output += (output.to(first_device),)
+                else:
+                    new_output += (output,)
+            return new_output
+        model.register_forward_hook(set_device_hook)
+        self.base_model = model
+        # breakpoint()
+        # self.base_model = AutoModelForCausalLMWithValueHead.from_pretrained(
+        #     rlhf_config.base_model_path,
+        #     load_in_8bit=rlhf_config.load_in_8bit,
+        #     device_map={"": Accelerator().local_process_index},
+        #     peft_config=rlhf_config.lora_config_rl,
+        # )
         self.ppo_trainer = PPOTrainer(
             config=self.ppo_config,
             model=self.base_model,

From 6896163460e6f81cefc8ae775a5416498ade6f70 Mon Sep 17 00:00:00 2001
From: llauraa23
Date: Sun, 10 Sep 2023 15:01:36 -0700
Subject: [PATCH 2/2] Resolve comment on config model path.

---
 example/rlhf/demo_rw_finetuning.py         | 3 +--
 example/rlhf/supervised_finetuning_demo.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/example/rlhf/demo_rw_finetuning.py b/example/rlhf/demo_rw_finetuning.py
index 09f8c1d..c6913c5 100644
--- a/example/rlhf/demo_rw_finetuning.py
+++ b/example/rlhf/demo_rw_finetuning.py
@@ -26,7 +26,6 @@
 
 # run reward model finetuning
 # config = pykoi.RLHFConfig(dataset_type="local_db")
-config = pykoi.RLHFConfig()
-config.base_model_path = "databricks/dolly-v2-3b"
+config = pykoi.RLHFConfig(reward_model_path="databricks/dolly-v2-3b")
 rlhf_step2_rft = pykoi.RewardFinetuning(config)
 rlhf_step2_rft.train_and_save("./models/rlhf_step2_rw")
diff --git a/example/rlhf/supervised_finetuning_demo.py b/example/rlhf/supervised_finetuning_demo.py
index 73c5906..bfc29a3 100644
--- a/example/rlhf/supervised_finetuning_demo.py
+++ b/example/rlhf/supervised_finetuning_demo.py
@@ -25,7 +25,6 @@
 print("My local database has {} samples in total".format(my_data_pd.shape[0]))
 
 # run supervised finetuning
-config = pykoi.RLHFConfig(base_model_path="elinas/llama-7b-hf-transformers-4.29", dataset_type="local_db")
-config.base_model_path = "databricks/dolly-v2-3b"
+config = pykoi.RLHFConfig(base_model_path="databricks/dolly-v2-3b", dataset_type="local_db")
 rlhf_step1_sft = pykoi.SupervisedFinetuning(config)
 rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
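Below is a minimal sketch of the three-step demo flow these two patches wire together, assuming the pykoi API shown in the demos above (pykoi.RLHFConfig, pykoi.SupervisedFinetuning, pykoi.RewardFinetuning). The step-3 entry point pykoi.RLFinetuning, its train_and_save method, and the "./models/rlhf_step3_rl" output path are assumptions for illustration and are not confirmed by the diff.

import pykoi

# Step 1: supervised finetuning; saves a peft adapter under models/rlhf_step1_sft.
sft_config = pykoi.RLHFConfig(base_model_path="databricks/dolly-v2-3b", dataset_type="local_db")
pykoi.SupervisedFinetuning(sft_config).train_and_save("./models/rlhf_step1_sft")

# Step 2: reward-model finetuning; saves a peft adapter under models/rlhf_step2_rw.
rw_config = pykoi.RLHFConfig(reward_model_path="databricks/dolly-v2-3b")
pykoi.RewardFinetuning(rw_config).train_and_save("./models/rlhf_step2_rw")

# Step 3: PPO finetuning, pointing at the adapters saved by steps 1 and 2,
# as demo_rl.py does after this patch.
rl_config = pykoi.RLHFConfig(
    base_model_path="models/rlhf_step1_sft",    # policy adapter from step 1
    reward_model_path="models/rlhf_step2_rw/",  # reward adapter from step 2
    dataset_type="huggingface",
    dataset_name="goldmermaid/stack_exchange_rank_10k_dataset",
    dataset_subset_rl="data",
)
pykoi.RLFinetuning(rl_config).train_and_save("./models/rlhf_step3_rl")  # assumed entry point and output path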