From 706e843a9f5b27c44db524463c288682301d2b5d Mon Sep 17 00:00:00 2001
From: Preetika764
Date: Fri, 9 Aug 2024 22:25:13 +0000
Subject: [PATCH] Adding support for Gemma2

---
 scripts/config.sh                              |  5 ++++-
 scripts/train_rosa.sh                          |  3 +++
 src/panza/data_preparation/summarize_emails.py |  6 +++++-
 src/panza/finetuning/train.py                  |  6 ++++--
 src/panza/utils/prompting.py                   | 10 ++++++++++
 5 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/scripts/config.sh b/scripts/config.sh
index 76522f1..a97707d 100755
--- a/scripts/config.sh
+++ b/scripts/config.sh
@@ -21,6 +21,7 @@ export MODEL_PRECISION=bf16 # precision at which the base model is stored; optio
 # export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
 export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct"
 # export PANZA_GENERATIVE_MODEL="microsoft/Phi-3-mini-4k-instruct"
+# export PANZA_GENERATIVE_MODEL="google/gemma-2-2b-it"
 
 lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]')
 if [[ ${lowercased} == *llama* ]]; then
@@ -29,8 +30,10 @@ elif [[ ${lowercased} == *mistral* ]]; then
     export MODEL_TYPE=mistralv2
 elif [[ ${lowercased} == *phi* ]]; then
     export MODEL_TYPE=phi3
+elif [[ ${lowercased} == *gemma* ]]; then
+    export MODEL_TYPE=gemma2
 else
-    echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral and Llama3 models. Exiting."
+    echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral, Phi, Gemma2, and Llama3 models. Exiting."
     exit
 fi
diff --git a/scripts/train_rosa.sh b/scripts/train_rosa.sh
index d100cd2..02aebb9 100755
--- a/scripts/train_rosa.sh
+++ b/scripts/train_rosa.sh
@@ -28,6 +28,9 @@ elif [[ ${MODEL_TYPE} == mistralv2 ]]; then
 elif [[ ${MODEL_TYPE} == phi3 ]]; then
     export LR=1e-5 # learning rate
     export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters
+elif [[ ${MODEL_TYPE} == gemma2 ]]; then
+    export LR=1e-5 # learning rate
+    export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters
 else
     echo "Model type ${MODEL_TYPE} not recognized! Panza only works with mistralv2, llama3 and phi3 models. Exiting."
     exit
diff --git a/src/panza/data_preparation/summarize_emails.py b/src/panza/data_preparation/summarize_emails.py
index 4585397..8de472e 100644
--- a/src/panza/data_preparation/summarize_emails.py
+++ b/src/panza/data_preparation/summarize_emails.py
@@ -17,6 +17,7 @@ sys.path.pop(0)
 
 MDL = os.environ.get("PANZA_GENERATIVE_MODEL")
+MDL_TYPE = os.environ.get("MODEL_TYPE")
 TEMP = 0.7
 TOP_P = 0.7
 TOP_K = 50
 
@@ -43,7 +44,10 @@ def __init__(self, model, dtype, temperature, top_k, top_p, summarization_prompt
             model, model_max_length=self.model.config.max_position_embeddings, trust_remote_code=True
         )
         self.tokenizer.padding_side = "left"
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        if MDL_TYPE == "gemma2":
+            self.tokenizer.pad_token = self.tokenizer.bos_token
+        else:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
         self.summarization_prompt = summarization_prompt
 
         _, self.prompt_end_wrapper, _, self.response_end_wrapper = (
diff --git a/src/panza/finetuning/train.py b/src/panza/finetuning/train.py
index 82c7c23..7aa7db0 100644
--- a/src/panza/finetuning/train.py
+++ b/src/panza/finetuning/train.py
@@ -166,15 +166,17 @@ def build_composer_peft_model(
     with init_empty_weights(include_buffers=False):
         model = AutoModelForCausalLM.from_pretrained(
             model_config.pretrained_model_name_or_path,
-            device_map='cpu' if quant_config is None else 'auto',
+            # device_map='cpu' if quant_config is None else 'auto',
             torch_dtype=compute_dtype,
             # load_in_4bit=weight_bias_dtype == '4bit',
             quantization_config=quant_config,
             trust_remote_code=True,
             use_auth_token=True,
             use_cache=False,
-            attn_implementation='eager'
+            attn_implementation='eager',
+            low_cpu_mem_usage=True,
         )
+        model.tie_weights()
     print('Model built!')
 
     if rosa_config is not None:
diff --git a/src/panza/utils/prompting.py b/src/panza/utils/prompting.py
index 958b720..4f00726 100644
--- a/src/panza/utils/prompting.py
+++ b/src/panza/utils/prompting.py
@@ -17,6 +17,11 @@
 PHI3_RESPONSE_START_WRAPPER = ""
 PHI3_RESPONSE_END_WRAPPER = "<|end|>"
 
+GEMMA2_PROMPT_START_WRAPPER = "<start_of_turn>user\n"
+GEMMA2_PROMPT_END_WRAPPER = "<end_of_turn>\n<start_of_turn>model\n"
+GEMMA2_RESPONSE_START_WRAPPER = ""
+GEMMA2_RESPONSE_END_WRAPPER = "<end_of_turn>"
+
 def create_prompt(
     user_input: Text,
     system_preamble: Text,
@@ -126,6 +131,11 @@ def get_model_special_tokens(model_name):
         prompt_end_wrapper = PHI3_PROMPT_END_WRAPPER
         response_start_wrapper = PHI3_RESPONSE_START_WRAPPER
         response_end_wrapper = PHI3_RESPONSE_END_WRAPPER
+    elif "gemma" in model_name.lower():
+        prompt_start_wrapper = GEMMA2_PROMPT_START_WRAPPER
+        prompt_end_wrapper = GEMMA2_PROMPT_END_WRAPPER
+        response_start_wrapper = GEMMA2_RESPONSE_START_WRAPPER
+        response_end_wrapper = GEMMA2_RESPONSE_END_WRAPPER
     else:
         raise ValueError(f"Presets missing for prompting model {model_name}")
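
Note (reviewer sketch, not part of the diff above): the GEMMA2_* wrapper strings hand-roll Gemma 2's chat format, so they can silently drift out of sync if the chat template shipped with the tokenizer ever changes. One way to guard against that is to compare them with tokenizer.apply_chat_template. The snippet below is a minimal sketch, assuming a transformers version with chat-template support and access to the google/gemma-2-2b-it checkpoint referenced in config.sh; the user_input string is made up for illustration.

    # Sanity check for the hand-rolled Gemma 2 wrappers (a sketch, not part of the patch).
    from transformers import AutoTokenizer

    GEMMA2_PROMPT_START_WRAPPER = "<start_of_turn>user\n"
    GEMMA2_PROMPT_END_WRAPPER = "<end_of_turn>\n<start_of_turn>model\n"

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    user_input = "Write a short email declining the meeting."  # illustrative only

    # Prompt as Panza assembles it from the wrapper constants.
    manual = GEMMA2_PROMPT_START_WRAPPER + user_input + GEMMA2_PROMPT_END_WRAPPER

    # Prompt as the tokenizer's own chat template assembles it.
    templated = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_input}],
        tokenize=False,
        add_generation_prompt=True,  # appends the "<start_of_turn>model\n" header
    )

    # The chat template prepends the BOS token to the raw string, while the
    # wrappers leave BOS insertion to the tokenizer, so account for it here.
    assert templated == tokenizer.bos_token + manual

If the comparison ever fails after a model or tokenizer update, the wrapper constants in src/panza/utils/prompting.py are the single place to adjust.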